Add Code World Model (CWM) #41199
base: main
Conversation
Thanks, left some comments to clean up. Btw, do we have converted weights already which we can use for the integration tests?
config = self.model_tester.get_config()
model = CwmModel(config)
model.to(torch_device)
model.eval()

# input longer than sliding window
seq_length = config.sliding_window + 10
input_ids = torch.randint(0, config.vocab_size, (1, seq_length), device=torch_device)

with torch.no_grad():
    outputs = model(input_ids)

self.assertEqual(outputs.last_hidden_state.shape, (1, seq_length, config.hidden_size))
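As a side note on what the `sliding_window` comparison above exercises, here is a minimal, dependency-free sketch of the local attention mask pattern. The helper name and window value are illustrative, not taken from the CWM implementation:

```python
def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Causal sliding-window mask: query position q may attend to key
    position k only if k is among the last `window` positions (q - window < k <= q)."""
    return [
        [0 <= q - k < window for k in range(seq_len)]
        for q in range(seq_len)
    ]

# With window=4, position 5 attends to keys 2..5 only,
# whereas full causal attention would allow keys 0..5.
mask = sliding_window_mask(seq_len=8, window=4)
allowed = [k for k, ok in enumerate(mask[5]) if ok]
print(allowed)  # [2, 3, 4, 5]
```

This is why the test feeds a sequence longer than `config.sliding_window`: only then do the local layers actually drop keys outside the window.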
better if we can make an integration test and check that the generated ids are correct
# no errors
self.assertIsNotNone(outputs.last_hidden_state)
not clear what we are testing for here, last hidden state can never be None, no?
def tearDown(self):
    cleanup(torch_device, gc_collect=True)

def test_cwm_small_model_forward(self):
let's use actual model for all below tests and check generated token ids or logit values for a few important cases (sliding window, simple generation etc)
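One possible shape for such an integration test, sketched below. The checkpoint name is a placeholder (no CWM repo id is given in this thread), and the expected ids would have to be captured from a trusted run of the converted weights:

```python
import unittest


class CwmIntegrationTest(unittest.TestCase):
    # NOTE: placeholder repo id, not a real checkpoint name.
    checkpoint = "facebook/cwm"

    def test_generation(self):
        # Imported lazily so this sketch stays importable without torch installed.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        model = AutoModelForCausalLM.from_pretrained(
            self.checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
        )
        inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
        out = model.generate(**inputs, max_new_tokens=8, do_sample=False)
        # EXPECTED_IDS must be filled in from a reference run of the real weights.
        EXPECTED_IDS = [...]  # placeholder
        self.assertEqual(out[0, inputs["input_ids"].shape[1]:].tolist(), EXPECTED_IDS)
```

Greedy decoding (`do_sample=False`) keeps the generated ids deterministic so they can be compared exactly.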
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, cwm
Adds the Code World Model (CWM) - https://ai.meta.com/research/publications/cwm-an-open-weights-llm-for-research-on-code-generation-with-world-models/
High-level implementation details:
The model repos are:
Note that for vLLM compatibility, the model `config.json` files still refer to `Llama3ForCausalLM` and a `llama` `model_type` (see example). vllm-project/vllm#25611 adds support for mapping `CwmForCausalLM` to the Llama3 model class in vLLM, since vLLM supports `Llama3` + `layer_types` with local/global attention (see docs). The model type in the `config.json` will be updated on HF (and the special automapping condition removed) once this PR is merged and a Transformers release containing the `CwmForCausalLM` model class has shipped.

@ArthurZucker, @zucchini-nlp
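For reference, a sketch of what a `layer_types` config fragment looks like. The field names follow the transformers convention for interleaved local/global attention; the concrete values here are illustrative, not taken from the CWM release:

```json
{
  "model_type": "llama",
  "architectures": ["Llama3ForCausalLM"],
  "sliding_window": 8192,
  "layer_types": [
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention"
  ]
}
```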
Supersedes #41188 due to some fork misery