
Conversation

jacobkahn
Contributor

Adds the Code World Model (CWM) - https://ai.meta.com/research/publications/cwm-an-open-weights-llm-for-research-on-code-generation-with-world-models/

High-level implementation details:

  • This is a GQA model with interleaved local (sliding-window) and global attention
  • Implemented on top of the HF Llama3 modeling code, with interleaved sliding-window attention added (see the sketch after this list)
  • Inheriting from Gemma2/3 would require weight remapping, which breaks VLLM compatibility and other components, so this is implemented using the existing causal mask utils from HF
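
For intuition, here is a minimal sketch of the two mask types involved. This is plain PyTorch for illustration only, not the actual HF masking utilities the implementation uses, and the 3:1 local/global interleave shown is just an example layout, not necessarily CWM's:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Global layer: position i attends to all positions j <= i.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # Local layer: position i attends only to j <= i with i - j < window.
    idx = torch.arange(seq_len)
    return (idx[None, :] <= idx[:, None]) & ((idx[:, None] - idx[None, :]) < window)

# Example interleave: every fourth layer is global, the rest are local.
# The real per-layer pattern is defined by the model config.
layer_masks = [
    causal_mask(16) if i % 4 == 3 else sliding_window_causal_mask(16, window=8)
    for i in range(8)
]
```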

The model repos are:

Note that for VLLM compatibility, the model config.json files still refer to Llama3ForCausalLM and a llama model_type (see example). vllm-project/vllm#25611 adds support for mapping CwmForCausalLM to the Llama3 model class in VLLM, since VLLM already supports Llama3 with layer_types for local/global attention (see docs). The model type in config.json will be updated on HF (and the special automapping condition removed) once this PR is merged and a Transformers release containing the CwmForCausalLM model class has shipped.
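
As a concrete illustration of the interim state, loading the config today reports the llama model_type (the repo id below is a placeholder, not an actual Hub repo name):

```python
from transformers import AutoConfig

# Placeholder repo id; substitute the actual CWM checkpoint on the Hub.
config = AutoConfig.from_pretrained("org/cwm-model")

# Prints "llama" while the interim VLLM-compatible config is in place;
# this will change once the Hub configs are updated after the release.
print(config.model_type)
```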

@ArthurZucker, @zucchini-nlp

Supersedes #41188 due to some fork misery

Member

@zucchini-nlp zucchini-nlp left a comment


Thanks, left some comments to clean up. Btw, do we already have converted weights that we can use for the integration tests?

Comment on lines +144 to +156
```python
config = self.model_tester.get_config()
model = CwmModel(config)
model.to(torch_device)
model.eval()

# input longer than sliding window
seq_length = config.sliding_window + 10
input_ids = torch.randint(0, config.vocab_size, (1, seq_length), device=torch_device)

with torch.no_grad():
    outputs = model(input_ids)

self.assertEqual(outputs.last_hidden_state.shape, (1, seq_length, config.hidden_size))
```
Member

better if we can make an integration test and check that the generated ids are correct
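
Something along these lines could work as a starting point (a rough sketch with a randomly initialized model and assumed Llama-style CwmConfig arguments; the real test would load the released checkpoint and compare against recorded token ids):

```python
import torch
from transformers import CwmConfig, CwmForCausalLM

# Tiny random model; a real integration test would use from_pretrained
# on the released checkpoint and compare against recorded token ids.
config = CwmConfig(
    vocab_size=128,
    hidden_size=32,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=64,
    sliding_window=16,
)
model = CwmForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    first = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    second = model.generate(input_ids, max_new_tokens=8, do_sample=False)

# Greedy decoding is deterministic, so repeated runs must agree exactly.
assert torch.equal(first, second)
```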

Comment on lines +184 to +185
```python
# no errors
self.assertIsNotNone(outputs.last_hidden_state)
```
Member

not clear what we are testing for here, last hidden state can never be None, no?
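
If the goal is just "forward pass runs without blowing up", an assertion on the values themselves would be more meaningful, e.g. (continuing the test's existing `outputs`):

```python
# Check the outputs are numerically sane rather than merely non-None.
self.assertFalse(torch.isnan(outputs.last_hidden_state).any())
self.assertFalse(torch.isinf(outputs.last_hidden_state).any())
```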

```python
def tearDown(self):
    cleanup(torch_device, gc_collect=True)

def test_cwm_small_model_forward(self):
```
Member

let's use the actual model for all the tests below and check generated token ids or logit values for a few important cases (sliding window, simple generation, etc.)
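
For the sliding-window case specifically, one self-contained check (a sketch assuming CwmConfig exposes the per-layer layer_types field used by other interleaved models): with identical weights, an all-global model and an all-local model must agree whenever the input is shorter than the window:

```python
import torch
from transformers import CwmConfig, CwmModel

common = dict(
    vocab_size=128, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2,
    intermediate_size=64, sliding_window=32,
)
# layer_types values follow the convention used by other interleaved
# models ("sliding_attention" / "full_attention"); assumed for CwmConfig.
full = CwmModel(CwmConfig(**common, layer_types=["full_attention"] * 2)).eval()
local = CwmModel(CwmConfig(**common, layer_types=["sliding_attention"] * 2)).eval()
local.load_state_dict(full.state_dict())

# With seq_len < sliding_window, the local mask reduces to the plain
# causal mask, so both models must produce identical hidden states.
input_ids = torch.randint(0, 128, (1, 16))
with torch.no_grad():
    a = full(input_ids).last_hidden_state
    b = local(input_ids).last_hidden_state
torch.testing.assert_close(a, b)
```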

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, cwm
