fix qwen2.5 vl training #1314

Open · wants to merge 3 commits into base: develop
55 changes: 29 additions & 26 deletions paddlemix/examples/qwen2_5_vl/qwen2_5_vl_finetune.py
@@ -20,7 +20,7 @@
 import sys
 import traceback
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Sequence, Any
+from typing import Any, Dict, Optional, Sequence

 import numpy as np
 import paddle
@@ -31,15 +31,20 @@
 from paddlenlp.trainer import PdArgumentParser, TrainingArguments, set_seed
 from paddlenlp.trainer.trainer import Trainer
 from paddlenlp.trainer.trainer_utils import get_last_checkpoint
+from paddlenlp.transformers.processing_utils import ProcessorMixin
 from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError

 from paddlemix.datasets.internvl_dataset import ConcatDataset, WeightedConcatDataset
 from paddlemix.models.qwen2_5_vl import MIXQwen2_5_Tokenizer
-from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration,
+)
 from paddlemix.models.qwen2_5_vl.supervised import _encode_supervised_example
 from paddlemix.models.qwen2_5_vl.template import TEMPLATES
-from paddlemix.processors.qwen2_5_vl_processing import Qwen2_5_VLImageProcessor, Qwen2_5_VLProcessor
-from paddlenlp.transformers.processing_utils import ProcessorMixin
+from paddlemix.processors.qwen2_5_vl_processing import (
+    Qwen2_5_VLImageProcessor,
+    Qwen2_5_VLProcessor,
+)

 Image.MAX_IMAGE_PIXELS = None
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -291,9 +296,6 @@ def get_transform(self):
         return self.processor.image_processor

     def multi_modal_get_item(self, data_item):
-        # Build transformation function
-        transform = self.get_transform()
-
         # Ensure the first conversation contains an image placeholder
         if "<image>" not in data_item["messages"][0]["content"]:
             data_item["messages"][0]["content"] = "<image>\n" + data_item["messages"][0]["content"]
@@ -352,7 +354,7 @@ def pure_text_get_item(self, data_item):
             attention_mask=attention_mask,
             images=[],
         )
-
+
         return ret

     def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
@@ -457,7 +459,7 @@ def __post_init__(self):

     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
         batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
-
+
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -467,9 +469,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
             batch_vidlens.append(len(videos))
             batch_input_ids.append(feature["input_ids"])

-        if (
-            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
-        ):
+        if self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0:
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
@@ -480,12 +480,16 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:

             if len(fake_input_ids) != 0:
                 if self.tokenizer.padding_side == "right":
-                    features[0]["input_ids"] = features[0]["input_ids"]+ fake_input_ids["input_ids"]
-                    features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids["input_ids"])
+                    features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids["input_ids"]
+                    features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(
+                        fake_input_ids["input_ids"]
+                    )
                     features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids["input_ids"])
                 else:
                     features[0]["input_ids"] = fake_input_ids["input_ids"] + features[0]["input_ids"]
-                    features[0]["attention_mask"] = [0] * len(fake_input_ids["input_ids"]) + features[0]["attention_mask"]
+                    features[0]["attention_mask"] = [0] * len(fake_input_ids["input_ids"]) + features[0][
+                        "attention_mask"
+                    ]
                     features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids["input_ids"]) + features[0]["labels"]

             batch_images = fake_images
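Review note: the point of the fake-token branch above is to keep multimodal kernels (mrope in particular) from ever seeing an all-text batch, while guaranteeing the dummy tokens cannot influence training: they get attention 0 and IGNORE_INDEX labels. A minimal sketch of the right-padding case, with hypothetical stand-in values:

```python
# Hypothetical stand-ins: fake_ids plays the role of fake_input_ids["input_ids"].
IGNORE_INDEX = -100

feature = {"input_ids": [101, 102, 103], "attention_mask": [1, 1, 1], "labels": [101, 102, 103]}
fake_ids = [9001, 9002]  # placeholder-image token ids, invented for illustration

feature["input_ids"] = feature["input_ids"] + fake_ids
feature["attention_mask"] = feature["attention_mask"] + [0] * len(fake_ids)  # never attended to
feature["labels"] = feature["labels"] + [IGNORE_INDEX] * len(fake_ids)       # excluded from the loss

assert feature["attention_mask"] == [1, 1, 1, 0, 0]
assert feature["labels"][-2:] == [IGNORE_INDEX, IGNORE_INDEX]
```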
@@ -499,7 +503,6 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
             token_type_ids = mm_inputs.pop("token_type_ids")
             for i, feature in enumerate(features):
                 feature["token_type_ids"] = token_type_ids[i]
-
         features: Dict[str, "paddle.Tensor"] = super().__call__(features)

         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2_5_vl mrope
@@ -514,27 +517,25 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:

             features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)

-
-
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
-            seq_len = features["input_ids"].size(1)
-            orig_len = cross_attention_mask.size(1)
-            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
+            seq_len = features["input_ids"].shape[1]
+            orig_len = cross_attention_mask.shape[1]
+            mm_inputs["cross_attention_mask"] = paddle.nn.functional.pad(
+                cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len)
+            )

         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()

         if "image_bound" in features:  # for minicpmv inputs
             bsz, seq_length = features["input_ids"].shape
-            features["position_ids"] = paddle.arange(seq_length).long().repeat(bsz, 1)
+            features["position_ids"] = paddle.arange(seq_length).long().tile([bsz, 1])
             return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}

         return features


 def main():
     parser = PdArgumentParser((ModelArguments, DataTrainingArguments, PreTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
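Review note: the `.size(1)` → `.shape[1]` and `.repeat(bsz, 1)` → `.tile([bsz, 1])` edits are real PyTorch-to-Paddle fixes, not style: in Paddle, `Tensor.size` is a property holding the total element count, and `tile` is the equivalent of torch's `repeat`. A quick sketch of the difference:

```python
import paddle

x = paddle.zeros([2, 5])
print(x.shape[1])  # 5  -- Paddle spelling of torch's x.size(1)
print(x.size)      # 10 -- total element count; calling x.size(1) would raise TypeError

bsz, seq_length = 2, 5
position_ids = paddle.arange(seq_length).tile([bsz, 1])
print(position_ids.shape)  # [2, 5]
```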
@@ -592,10 +593,12 @@ def main():
     print(f"Loading Tokenizer: {tokenizer_path}")

     MODEL_NAME = model_args.model_name_or_path
-    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype=dtype, attn_implementation="flash_attention_2")
+    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype=dtype, attn_implementation="sdpa")
     image_processor = Qwen2_5_VLImageProcessor()
     tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(MODEL_NAME, padding_side="right")
-    processor = Qwen2_5_VLProcessor(image_processor, tokenizer)
+    processor = Qwen2_5_VLProcessor(
+        image_processor, tokenizer, image_max_pixels=model_args.image_resolution**2, image_min_pixels=1024
+    )

     tokenizer.tokenizer_path = tokenizer_path
     tokenizer.model_max_length = data_args.max_seq_length
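Review note: two behavioral changes land in this hunk: the attention backend drops from `flash_attention_2` to `sdpa` (presumably so training no longer requires a FlashAttention build), and the processor now receives an explicit pixel budget instead of its defaults. A hedged sketch of the budget arithmetic, assuming `image_resolution` is a `ModelArguments` field measured in pixels per side:

```python
# Assumed example value; check ModelArguments for the real default.
image_resolution = 512

image_max_pixels = image_resolution**2  # 262144: area ceiling handed to the processor
image_min_pixels = 1024                 # 32 * 32 floor, hard-coded in the call above

# _preprocess_image later rescales every image so that
# image_min_pixels <= width * height <= image_max_pixels.
assert image_min_pixels <= image_max_pixels
```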
60 changes: 27 additions & 33 deletions paddlemix/models/qwen2_5_vl/formatter.py
@@ -16,34 +16,27 @@
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Dict, Optional, Sequence, Set, Union

 from typing_extensions import override

-from .tool_utils import get_tool_utils
+from .tool_utils import FunctionCall, get_tool_utils

 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]


-if TYPE_CHECKING:
-    from .tool_utils import FunctionCall
-
-
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
     tool_format: Optional[str] = None

     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
-        r"""
-        Forms a list of slots according to the inputs to encode.
-        """
+        r"""Forms a list of slots according to the inputs to encode."""
         ...

-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extract a list of tuples from the response message if using tools.
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract a list of tuples from the response message if using tools.

         Each tuple consists of function name and function arguments.
         """
@@ -84,50 +77,51 @@ def apply(self, **kwargs) -> SLOTS:
             if isinstance(slot, str):
                 for name, value in kwargs.items():
                     if not isinstance(value, str):
-                        raise RuntimeError("Expected a string, got {}".format(value))
+                        raise RuntimeError(f"Expected a string, got {value}")

                     slot = slot.replace("{{" + name + "}}", value, 1)
                 elements.append(slot)
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")

         return elements


 @dataclass
-class FunctionFormatter(Formatter):
+class FunctionFormatter(StringFormatter):
     def __post_init__(self):
-        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
+        super().__post_init__()
+        self.tool_utils = get_tool_utils(self.tool_format)

     @override
     def apply(self, **kwargs) -> SLOTS:
-        content = kwargs.pop("content")
-        functions: List[Tuple[str, str]] = []
+        content: str = kwargs.pop("content")
+        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
+        thought = re.search(regex, content)
+        if thought:
+            content = content.replace(thought.group(0), "")
+
+        functions: list[FunctionCall] = []
         try:
             tool_calls = json.loads(content)
             if not isinstance(tool_calls, list):  # parallel function call
                 tool_calls = [tool_calls]

             for tool_call in tool_calls:
-                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+                functions.append(
+                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
+                )

         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in function message: {}".format(str([content])))  # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string

-        elements = []
-        for name, arguments in functions:
-            for slot in self.slots:
-                if isinstance(slot, str):
-                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-                    elements.append(slot)
-                elif isinstance(slot, (dict, set)):
-                    elements.append(slot)
-                else:
-                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+        function_str = self.tool_utils.function_formatter(functions)
+        if thought:
+            function_str = thought.group(0) + function_str

-        return elements
+        return super().apply(content=function_str)


 @dataclass
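Review note: the rewritten `FunctionFormatter.apply` now peels off an optional `<think>...</think>` block before JSON-decoding the tool calls, delegates string formatting to `tool_utils.function_formatter`, and re-prepends the thought. A standalone sketch of that round-trip; the `<tool_call>` output format below is a stub, since the real slot layout lives in `tool_utils`:

```python
import json
import re

content = '<think>need the forecast</think>{"name": "get_weather", "arguments": {"city": "Beijing"}}'

regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
thought = re.search(regex, content)
if thought:
    content = content.replace(thought.group(0), "")

tool_calls = json.loads(content)
if not isinstance(tool_calls, list):  # parallel function call
    tool_calls = [tool_calls]

# Stub for tool_utils.function_formatter; the real format is tool-specific.
function_str = "".join(
    "<tool_call>" + json.dumps({"name": c["name"], "arguments": c["arguments"]}) + "</tool_call>"
    for c in tool_calls
)
if thought:
    function_str = thought.group(0) + function_str

print(function_str)
# <think>need the forecast</think><tool_call>{"name": "get_weather", "arguments": {"city": "Beijing"}}</tool_call>
```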
@@ -142,8 +136,8 @@ def apply(self, **kwargs) -> SLOTS:
             tools = json.loads(content)
             return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content])))  # flat string
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

     @override
-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
         return self.tool_utils.tool_extractor(content)
11 changes: 6 additions & 5 deletions paddlemix/models/qwen2_5_vl/mm_plugin.py
@@ -26,6 +26,7 @@
 from PIL import Image
 from PIL.Image import Image as ImageObject
 from typing_extensions import override
+
 from paddlemix.processors.processing_utils import BaseImageProcessor

 IGNORE_INDEX = -100
@@ -70,12 +71,12 @@ def _preprocess_image(
         if (image.width * image.height) > image_max_pixels:
             resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
-            image = image.resize((width, height), resample=Image.Resampling.NEAREST)
+            image = image.resize((width, height))

         if (image.width * image.height) < image_min_pixels:
             resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
-            image = image.resize((width, height), resample=Image.Resampling.NEAREST)
+            image = image.resize((width, height))

         if image.mode != "RGB":
             image = image.convert("RGB")
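Review note: removing `resample=Image.Resampling.NEAREST` makes `resize` fall back to Pillow's default filter (bicubic for RGB in current Pillow), which is usually the better choice for photographic inputs; nearest-neighbor downscaling is noticeably lossy. The surrounding area-budget logic, as a self-contained sketch:

```python
import math

from PIL import Image


def cap_pixels(image, image_max_pixels=768 * 768, image_min_pixels=32 * 32):
    """Rescale so image_min_pixels <= w * h <= image_max_pixels, keeping aspect ratio."""
    if image.width * image.height > image_max_pixels:
        factor = math.sqrt(image_max_pixels / (image.width * image.height))
        image = image.resize((int(image.width * factor), int(image.height * factor)))
    if image.width * image.height < image_min_pixels:
        factor = math.sqrt(image_min_pixels / (image.width * image.height))
        image = image.resize((int(image.width * factor), int(image.height * factor)))
    return image


img = cap_pixels(Image.new("RGB", (4000, 3000)))
assert img.width * img.height <= 768 * 768  # aspect ratio stays 4:3
```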
@@ -168,13 +169,12 @@ def _get_mm_inputs(
         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
-                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                image_max_pixels=getattr(processor, "image_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "image_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             input_dict["videos"] = videos
-
         mm_inputs = {}
         if image_processor != video_processor:
             if input_dict.get("images") is not None:
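Review note: the old code read `video_max_pixels`/`video_min_pixels`, attributes nothing in this PR ever sets, so video frames silently used the fallbacks. After the rename, videos share the `image_max_pixels`/`image_min_pixels` budget that the training script now sets on the processor. The lookup pattern, sketched with a hypothetical stand-in:

```python
class FakeProcessor:  # hypothetical; only the attributes matter here
    image_max_pixels = 512 * 512
    image_min_pixels = 1024


processor = FakeProcessor()
print(getattr(processor, "image_max_pixels", 256 * 256))  # 262144 (set -> used)
print(getattr(processor, "video_fps", 2.0))               # 2.0 (unset -> fallback)
```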
@@ -337,6 +337,7 @@ def get_mm_inputs(
     "qwen2_5_vl": Qwen2_5_vlPlugin,
 }

+
 def get_mm_plugin(
     name: str,
     image_token: Optional[str] = None,