diff --git a/paddlemix/examples/qwen2_5_vl/qwen2_5_vl_finetune.py b/paddlemix/examples/qwen2_5_vl/qwen2_5_vl_finetune.py
index 445b55dd5..c6f80b399 100644
--- a/paddlemix/examples/qwen2_5_vl/qwen2_5_vl_finetune.py
+++ b/paddlemix/examples/qwen2_5_vl/qwen2_5_vl_finetune.py
@@ -20,7 +20,7 @@
 import sys
 import traceback
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Sequence, Any
+from typing import Any, Dict, Optional, Sequence
 
 import numpy as np
 import paddle
@@ -31,15 +31,20 @@
 from paddlenlp.trainer import PdArgumentParser, TrainingArguments, set_seed
 from paddlenlp.trainer.trainer import Trainer
 from paddlenlp.trainer.trainer_utils import get_last_checkpoint
+from paddlenlp.transformers.processing_utils import ProcessorMixin
 from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
 
 from paddlemix.datasets.internvl_dataset import ConcatDataset, WeightedConcatDataset
 from paddlemix.models.qwen2_5_vl import MIXQwen2_5_Tokenizer
-from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration,
+)
 from paddlemix.models.qwen2_5_vl.supervised import _encode_supervised_example
 from paddlemix.models.qwen2_5_vl.template import TEMPLATES
-from paddlemix.processors.qwen2_5_vl_processing import Qwen2_5_VLImageProcessor, Qwen2_5_VLProcessor
-from paddlenlp.transformers.processing_utils import ProcessorMixin
+from paddlemix.processors.qwen2_5_vl_processing import (
+    Qwen2_5_VLImageProcessor,
+    Qwen2_5_VLProcessor,
+)
 
 Image.MAX_IMAGE_PIXELS = None
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -291,9 +296,6 @@ def get_transform(self):
         return self.processor.image_processor
 
     def multi_modal_get_item(self, data_item):
-        # Build transformation function
-        transform = self.get_transform()
-
         # Ensure the first conversation contains an image placeholder
         if "<image>" not in data_item["messages"][0]["content"]:
             data_item["messages"][0]["content"] = "<image>\n" + data_item["messages"][0]["content"]
@@ -352,7 +354,7 @@ def pure_text_get_item(self, data_item):
             attention_mask=attention_mask,
             images=[],
         )
-
+
         return ret
 
     def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
@@ -457,7 +459,7 @@ def __post_init__(self):
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
         batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
-
+
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -467,9 +469,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
             batch_vidlens.append(len(videos))
             batch_input_ids.append(feature["input_ids"])
 
-        if (
-            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
-        ):
+        if self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0:
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
@@ -480,12 +480,16 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
 
             if len(fake_input_ids) != 0:
                 if self.tokenizer.padding_side == "right":
-                    features[0]["input_ids"] = features[0]["input_ids"]+ fake_input_ids["input_ids"]
-                    features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids["input_ids"])
+
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids["input_ids"] + features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len( + fake_input_ids["input_ids"] + ) features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids["input_ids"]) else: features[0]["input_ids"] = fake_input_ids["input_ids"] + features[0]["input_ids"] - features[0]["attention_mask"] = [0] * len(fake_input_ids["input_ids"]) + features[0]["attention_mask"] + features[0]["attention_mask"] = [0] * len(fake_input_ids["input_ids"]) + features[0][ + "attention_mask" + ] features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids["input_ids"]) + features[0]["labels"] batch_images = fake_images @@ -499,7 +503,6 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens token_type_ids = mm_inputs.pop("token_type_ids") for i, feature in enumerate(features): feature["token_type_ids"] = token_type_ids[i] - features: Dict[str, "paddle.Tensor"] = super().__call__(features) if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2_5_vl mrope @@ -514,13 +517,13 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs) - - if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled cross_attention_mask = mm_inputs.pop("cross_attention_mask") - seq_len = features["input_ids"].size(1) - orig_len = cross_attention_mask.size(1) - mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len)) + seq_len = features["input_ids"].shape[1] + orig_len = cross_attention_mask.shape[1] + mm_inputs["cross_attention_mask"] = paddle.nn.functional.pad( + cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len) + ) features.update(mm_inputs) if isinstance(features.get("pixel_values"), list): # for pixtral inputs @@ -528,13 +531,11 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens if "image_bound" in features: # for minicpmv inputs bsz, seq_length = features["input_ids"].shape - features["position_ids"] = paddle.arange(seq_length).long().repeat(bsz, 1) + features["position_ids"] = paddle.arange(seq_length).long().tile([bsz, 1]) return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]} - return features - def main(): parser = PdArgumentParser((ModelArguments, DataTrainingArguments, PreTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): @@ -592,10 +593,12 @@ def main(): print(f"Loading Tokenizer: {tokenizer_path}") MODEL_NAME = model_args.model_name_or_path - model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype=dtype, attn_implementation="flash_attention_2") + model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype=dtype, attn_implementation="sdpa") image_processor = Qwen2_5_VLImageProcessor() tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(MODEL_NAME, padding_side="right") - processor = Qwen2_5_VLProcessor(image_processor, tokenizer) + processor = Qwen2_5_VLProcessor( + image_processor, tokenizer, image_max_pixels=model_args.image_resolution**2, image_min_pixels=1024 + ) tokenizer.tokenizer_path = tokenizer_path tokenizer.model_max_length = data_args.max_seq_length diff --git a/paddlemix/models/qwen2_5_vl/formatter.py b/paddlemix/models/qwen2_5_vl/formatter.py index 31fefd902..a841b9f53 100644 --- 
a/paddlemix/models/qwen2_5_vl/formatter.py
+++ b/paddlemix/models/qwen2_5_vl/formatter.py
@@ -16,19 +16,15 @@
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Dict, Optional, Sequence, Set, Union
 
 from typing_extensions import override
 
-from .tool_utils import get_tool_utils
+from .tool_utils import FunctionCall, get_tool_utils
 
 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
 
-if TYPE_CHECKING:
-    from .tool_utils import FunctionCall
-
-
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
@@ -36,14 +32,11 @@ class Formatter(ABC):
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
-        r"""
-        Forms a list of slots according to the inputs to encode.
-        """
+        r"""Forms a list of slots according to the inputs to encode."""
         ...
 
-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extract a list of tuples from the response message if using tools.
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract a list of tuples from the response message if using tools.
 
         Each tuple consists of function name and function arguments.
         """
@@ -84,50 +77,51 @@ def apply(self, **kwargs) -> SLOTS:
             if isinstance(slot, str):
                 for name, value in kwargs.items():
                     if not isinstance(value, str):
-                        raise RuntimeError("Expected a string, got {}".format(value))
+                        raise RuntimeError(f"Expected a string, got {value}")
 
                     slot = slot.replace("{{" + name + "}}", value, 1)
                 elements.append(slot)
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
 
         return elements
 
 
 @dataclass
-class FunctionFormatter(Formatter):
+class FunctionFormatter(StringFormatter):
     def __post_init__(self):
-        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
+        super().__post_init__()
+        self.tool_utils = get_tool_utils(self.tool_format)
 
     @override
     def apply(self, **kwargs) -> SLOTS:
-        content = kwargs.pop("content")
-        functions: List[Tuple[str, str]] = []
+        content: str = kwargs.pop("content")
+        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
+        thought = re.search(regex, content)
+        if thought:
+            content = content.replace(thought.group(0), "")
+
+        functions: list[FunctionCall] = []
         try:
             tool_calls = json.loads(content)
             if not isinstance(tool_calls, list):  # parallel function call
                 tool_calls = [tool_calls]
 
             for tool_call in tool_calls:
-                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+                functions.append(
+                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
+                )
 
         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in function message: {}".format(str([content])))  # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string
 
-        elements = []
-        for name, arguments in functions:
-            for slot in self.slots:
-                if isinstance(slot, str):
-                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-                    elements.append(slot)
-                elif isinstance(slot, (dict, set)):
-                    elements.append(slot)
-                else:
-                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+        function_str = self.tool_utils.function_formatter(functions)
+        if thought:
+            function_str = 
thought.group(0) + function_str - return elements + return super().apply(content=function_str) @dataclass @@ -142,8 +136,8 @@ def apply(self, **kwargs) -> SLOTS: tools = json.loads(content) return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""] except json.JSONDecodeError: - raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string + raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string @override - def extract(self, content: str) -> Union[str, List["FunctionCall"]]: + def extract(self, content: str) -> Union[str, list["FunctionCall"]]: return self.tool_utils.tool_extractor(content) diff --git a/paddlemix/models/qwen2_5_vl/mm_plugin.py b/paddlemix/models/qwen2_5_vl/mm_plugin.py index cd184f1ed..e2b635dbb 100644 --- a/paddlemix/models/qwen2_5_vl/mm_plugin.py +++ b/paddlemix/models/qwen2_5_vl/mm_plugin.py @@ -26,6 +26,7 @@ from PIL import Image from PIL.Image import Image as ImageObject from typing_extensions import override + from paddlemix.processors.processing_utils import BaseImageProcessor IGNORE_INDEX = -100 @@ -70,12 +71,12 @@ def _preprocess_image( if (image.width * image.height) > image_max_pixels: resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if (image.width * image.height) < image_min_pixels: resize_factor = math.sqrt(image_min_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if image.mode != "RGB": image = image.convert("RGB") @@ -168,13 +169,12 @@ def _get_mm_inputs( if len(videos) != 0: videos = self._regularize_videos( videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + image_max_pixels=getattr(processor, "image_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "image_min_pixels", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) input_dict["videos"] = videos - mm_inputs = {} if image_processor != video_processor: if input_dict.get("images") is not None: @@ -337,6 +337,7 @@ def get_mm_inputs( "qwen2_5_vl": Qwen2_5_vlPlugin, } + def get_mm_plugin( name: str, image_token: Optional[str] = None, diff --git a/paddlemix/models/qwen2_5_vl/template.py b/paddlemix/models/qwen2_5_vl/template.py index 9d07e1cf7..55a649ff9 100644 --- a/paddlemix/models/qwen2_5_vl/template.py +++ b/paddlemix/models/qwen2_5_vl/template.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re from dataclasses import dataclass from enum import Enum, unique from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -46,12 +47,13 @@ class Template: format_function: "Formatter" format_observation: "Formatter" format_tools: "Formatter" - format_separator: "Formatter" format_prefix: "Formatter" default_system: str - stop_words: List[str] + stop_words: list[str] + thought_words: tuple[str, str] efficient_eos: bool replace_eos: bool + replace_jinja_template: bool mm_plugin: "BasePlugin" def encode_oneturn( @@ -91,6 +93,27 @@ def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: """ return self.format_tools.extract(content) + def get_stop_token_ids(self, tokenizer: "PretrainedTokenizer") -> list[int]: + r"""Return stop token ids.""" + stop_token_ids = {tokenizer.eos_token_id} + for token in self.stop_words: + stop_token_ids.add(tokenizer.convert_tokens_to_ids(token)) + + return list(stop_token_ids) + + def add_thought(self, content: str) -> str: + r"""Add empty thought to assistant message.""" + return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content + + def remove_thought(self, content: str) -> str: + r"""Remove thought from assistant message.""" + pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL) + return re.sub(pattern, "", content).lstrip("\n") + + def get_thought_word_ids(self, tokenizer: "PretrainedTokenizer") -> list[int]: + r"""Get the token ids of thought words.""" + return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False) + def _encode( self, tokenizer: "PretrainedTokenizer", @@ -114,16 +137,13 @@ def _encode( tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) - if i > 0 and i % 2 == 0: - elements += self.format_separator.apply() - - if message["role"] == Role.USER.value: + if message["role"] == Role.USER: elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) - elif message["role"] == Role.ASSISTANT.value: + elif message["role"] == Role.ASSISTANT: elements += self.format_assistant.apply(content=message["content"]) - elif message["role"] == Role.OBSERVATION.value: + elif message["role"] == Role.OBSERVATION: elements += self.format_observation.apply(content=message["content"]) - elif message["role"] == Role.FUNCTION.value: + elif message["role"] == Role.FUNCTION: elements += self.format_function.apply(content=message["content"]) else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) @@ -153,6 +173,137 @@ def _convert_elements_to_ids(self, tokenizer: "PretrainedTokenizer", elements: " return token_ids + @staticmethod + def _add_or_replace_eos_token(tokenizer: "PretrainedTokenizer", eos_token: str) -> None: + r"""Add or replace eos token to the tokenizer.""" + if tokenizer.eos_token == eos_token: + return + + is_added = tokenizer.eos_token_id is None + num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) + + if is_added: + print(f"Add eos token: {tokenizer.eos_token}.") + else: + print(f"Replace eos token: {tokenizer.eos_token}.") + + if num_added_tokens > 0: + print("New tokens have been added, make sure `resize_vocab` is True.") + + def fix_special_tokens(self, tokenizer: "PretrainedTokenizer") -> None: + r"""Add eos token and pad token to the tokenizer.""" + stop_words = self.stop_words + if self.replace_eos: + if not stop_words: + 
raise ValueError("Stop words are required to replace the EOS token.") + + self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) + stop_words = stop_words[1:] + + if tokenizer.eos_token_id is None: + self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + print(f"Add pad token: {tokenizer.pad_token}") + + if stop_words: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False + ) + print("Add {} to stop words.".format(",".join(stop_words))) + if num_added_tokens > 0: + print("New tokens have been added, make sure `resize_vocab` is True.") + + @staticmethod + def _jinja_escape(content: str) -> str: + r"""Escape single quotes in content.""" + return content.replace("'", r"\'") + + @staticmethod + def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PretrainedTokenizer", placeholder: str = "content") -> str: + r"""Convert slots to jinja template.""" + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'") + if len(slot_pieces) > 1: + slot_items.append(placeholder) + if slot_pieces[1]: + slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'") + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append("'" + tokenizer.bos_token + "'") + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append("'" + tokenizer.eos_token + "'") + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return " + ".join(slot_items) + + def _get_jinja_template(self, tokenizer: "PretrainedTokenizer") -> str: + r"""Return the jinja template.""" + prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) + system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message") + user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) + assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer) + jinja_template = "" + if prefix: + jinja_template += "{{ " + prefix + " }}" + + if self.default_system: + jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" + "{% if system_message is defined %}{{ " + system + " }}{% endif %}" + "{% for message in loop_messages %}" + "{% set content = message['content'] %}" + "{% if message['role'] == 'user' %}" + "{{ " + user + " }}" + "{% elif message['role'] == 'assistant' %}" + "{{ " + assistant + " }}" + "{% endif %}" + "{% endfor %}" + ) + return jinja_template + + def fix_jinja_template(self, tokenizer: "PretrainedTokenizer") -> None: + r"""Replace the jinja template in the tokenizer.""" + if tokenizer.chat_template is None or self.replace_jinja_template: + try: + tokenizer.chat_template = self._get_jinja_template(tokenizer) + except ValueError as e: + print(f"Cannot add this chat template to tokenizer: {e}.") + + @staticmethod + def _convert_slots_to_ollama( + slots: "SLOTS", tokenizer: "PretrainedTokenizer", placeholder: str = 
"content" + ) -> str: + r"""Convert slots to ollama template.""" + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append(slot_pieces[0]) + if len(slot_pieces) > 1: + slot_items.append("{{ " + placeholder + " }}") + if slot_pieces[1]: + slot_items.append(slot_pieces[1]) + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append(tokenizer.bos_token) + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append(tokenizer.eos_token) + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return "".join(slot_items) + TEMPLATES: Dict[str, "Template"] = {} @@ -165,13 +316,15 @@ def _register_template( format_function: Optional["Formatter"] = None, format_observation: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None, - format_separator: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None, default_system: str = "", - stop_words: Sequence[str] = [], + stop_words: Optional[list[str]] = None, + thought_words: Optional[tuple[str, str]] = None, efficient_eos: bool = False, replace_eos: bool = False, + replace_jinja_template: bool = False, mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), + template_class: type["Template"] = Template, ) -> None: r""" Registers a chat template. @@ -199,13 +352,13 @@ def _register_template( ) ``` """ - eos_slots = [] if efficient_eos else [{"eos_token"}] template_class = Template + + default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] default_user_formatter = StringFormatter(slots=["{{content}}"]) - default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) - default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default") + default_assistant_formatter = StringFormatter(slots=default_slots) + default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default") - default_separator_formatter = EmptyFormatter() default_prefix_formatter = EmptyFormatter() TEMPLATES[name] = template_class( format_user=format_user or default_user_formatter, @@ -214,12 +367,13 @@ def _register_template( format_function=format_function or default_function_formatter, format_observation=format_observation or format_user or default_user_formatter, format_tools=format_tools or default_tool_formatter, - format_separator=format_separator or default_separator_formatter, format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, - stop_words=stop_words, + stop_words=stop_words or [], + thought_words=thought_words or ("", ""), efficient_eos=efficient_eos, replace_eos=replace_eos, + replace_jinja_template=replace_jinja_template, mm_plugin=mm_plugin, ) @@ -227,9 +381,13 @@ def _register_template( _register_template( name="qwen2_5_vl", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_separator=EmptyFormatter(slots=["\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], 
tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
diff --git a/paddlemix/models/qwen2_5_vl/tool_utils.py b/paddlemix/models/qwen2_5_vl/tool_utils.py
index b2308335e..29cb88833 100644
--- a/paddlemix/models/qwen2_5_vl/tool_utils.py
+++ b/paddlemix/models/qwen2_5_vl/tool_utils.py
@@ -34,6 +34,13 @@
     "```\n"
 )
 
+QWEN_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+    "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+    """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
+    """"arguments": <args-json-object>}}\n</tool_call>"""
+)
 
 FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
 
@@ -44,13 +51,13 @@ class ToolUtils(ABC):
     Base class for tool utilities.
     """
 
-    @staticmethod
-    @abstractmethod
-    def get_function_slots() -> SLOTS:
-        r"""
-        Gets a list of slots corresponding to a single function call.
-        """
-        ...
+    # @staticmethod
+    # @abstractmethod
+    # def get_function_slots() -> SLOTS:
+    #     r"""
+    #     Gets a list of slots corresponding to a single function call.
+    #     """
+    #     ...
 
     @staticmethod
     @abstractmethod
@@ -130,9 +137,54 @@ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         return results
 
 
-TOOLS = {
-    "default": DefaultToolUtils(),
-}
+class QwenToolUtils(ToolUtils):
+    r"""Qwen 2.5 tool using template."""
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @override
+    @staticmethod
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        function_texts = []
+        for name, arguments in functions:
+            function_texts.append(
+                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
+            )
+
+        return "\n".join(function_texts)
+
+    @override
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
+        tool_match: list[str] = re.findall(regex, content)
+        if not tool_match:
+            return content
+
+        results = []
+        for tool in tool_match:
+            try:
+                tool = json.loads(tool.strip())
+            except json.JSONDecodeError:
+                return content
+
+            if "name" not in tool or "arguments" not in tool:
+                return content
+
+            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+
+        return results
+
+
+TOOLS = {"default": DefaultToolUtils(), "qwen": QwenToolUtils()}
 
 
 def get_tool_utils(name: str) -> "ToolUtils":
diff --git a/paddlemix/processors/qwen2_5_vl_processing.py b/paddlemix/processors/qwen2_5_vl_processing.py
index 37c2abb3c..59fc1d48c 100644
--- a/paddlemix/processors/qwen2_5_vl_processing.py
+++ b/paddlemix/processors/qwen2_5_vl_processing.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import base64 import math import os @@ -80,6 +94,7 @@ "Qwen2_5_VLImageProcessor", ] + def is_scaled_image(image: np.ndarray) -> bool: """ Checks to see whether the pixel values have already been rescaled to [0, 1]. @@ -104,16 +119,19 @@ class Qwen2_5_VLProcessor(ProcessorMixin): chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. """ + attributes = ["image_processor", "tokenizer"] image_processor_class = "Qwen2_5_VLImageProcessor" - tokenizer_class = 'MIXQwen2_5_Tokenizer' + tokenizer_class = "MIXQwen2_5_Tokenizer" # , 'Qwen2TokenizerFast' def __init__(self, image_processor, text_processor, **kwargs): super().__init__(image_processor, text_processor) - self.image_processor.min_pixels = kwargs.get("min_pixels", 3136) - self.image_processor.max_pixels = kwargs.get("max_pixels", 12845056) - # self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + + # qwen2.5-vl training (liaojincheng) image_min_pixels is used for template get mm_input + self.image_min_pixels = kwargs.get("image_min_pixels", self.image_processor.min_pixels) + self.image_max_pixels = kwargs.get("image_max_pixels", self.image_processor.max_pixels) + # self.image_token = "" if not hasattr(tokenizer, "image_token") else tokenizer.image_token # self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token def __call__( @@ -171,19 +189,17 @@ def __call__( image_grid_thw = None if videos is not None: videos_inputs = self.image_processor(images=None, videos=videos, return_tensors=return_tensors) - video_grid_thw = videos_inputs['video_grid_thw'] - fps = videos_inputs.pop('fps', 2.0) + video_grid_thw = videos_inputs["video_grid_thw"] + fps = videos_inputs.pop("fps", 2.0) if isinstance(fps, (int, float)): - second_per_grid_ts = [self.image_processor. - temporal_patch_size / fps] * len(video_grid_thw) - elif hasattr(fps, '__len__') and len(fps) == len(video_grid_thw): - second_per_grid_ts = [(self.image_processor. - temporal_patch_size / tmp) for tmp in fps] + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [(self.image_processor.temporal_patch_size / tmp) for tmp in fps] else: raise ValueError( f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." - ) - videos_inputs.update({'second_per_grid_ts': second_per_grid_ts}) + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) else: videos_inputs = {} video_grid_thw = None @@ -237,7 +253,6 @@ def model_input_names(self): image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) - def post_process_image_text_to_text(self, generated_outputs): """ Post-process the output of the model to decode the text. 
@@ -250,17 +265,9 @@ def post_process_image_text_to_text(self, generated_outputs): Returns: `List[str]`: The decoded text. """ - return self.tokenizer.batch_decode(generated_outputs, - skip_special_tokens=True, clean_up_tokenization_spaces=False) - - @property - def model_input_names(self): - tokenizer_input_names = self.tokenizer.model_input_names - image_processor_input_names = self.image_processor.model_input_names - return list(dict.fromkeys(tokenizer_input_names + - image_processor_input_names)) - - + return self.tokenizer.batch_decode( + generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) def make_batched_images(images) -> List[List[ImageInput]]: @@ -302,6 +309,7 @@ def make_batched_videos(videos) -> List[VideoInput]: raise ValueError(f"Could not make batched video from {videos}") + class Qwen2_5_VLImageProcessor(BaseImageProcessor): """ Constructs a Qwen2.5-VL image processor that dynamically resizes images based on the original images. @@ -335,7 +343,13 @@ class Qwen2_5_VLImageProcessor(BaseImageProcessor): The merge size of the vision encoder to llm encoder. """ - model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + model_input_names = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] def __init__( self, @@ -347,8 +361,10 @@ def __init__( image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_convert_rgb: bool = True, + size: Dict[str, int] = None, min_pixels: int = 56 * 56, - max_pixels: int = 28 * 28 * 1280, + # max_pixels: int = 28 * 28 * 1280, + max_pixels: int = 12845056, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2, @@ -362,12 +378,24 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD - self.min_pixels = min_pixels - self.max_pixels = max_pixels self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size - self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} + + # liaojincheng fix qwen2.5-vl image processor + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + else: + size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + self.min_pixels = size["shortest_edge"] + self.max_pixels = size["longest_edge"] + self.size = size + self.do_convert_rgb = do_convert_rgb def _preprocess( @@ -441,7 +469,6 @@ def _preprocess( processed_images = [] for image in images: - if do_resize: resized_height, resized_width = smart_resize( height, @@ -450,9 +477,12 @@ def _preprocess( min_pixels=self.min_pixels, max_pixels=self.max_pixels, ) - image = image.astype('uint8') #TODO : 需要手动加上,否则多除255 导致结果会出错 + image = image.astype("uint8") # TODO : 需要手动加上,否则多除255 导致结果会出错 image = resize( - image, size=(resized_height, resized_width), resample=resample, data_format=input_data_format, + image, + size=(resized_height, resized_width), + resample=resample, + data_format=input_data_format, ) if do_rescale: @@ -498,6 
+528,8 @@ def preprocess( videos: VideoInput = None, do_resize: bool = None, size: Dict[str, int] = None, + min_pixels: int = None, + max_pixels: int = None, resample: PILImageResampling = None, do_rescale: bool = None, rescale_factor: float = None, @@ -556,6 +588,18 @@ def preprocess( - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ + if size is not None: + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + min_pixels = size["shortest_edge"] + else: + size = self.size + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size resample = resample if resample is not None else self.resample @@ -638,31 +682,33 @@ def floor_by_factor(number: int, factor: int) -> int: def smart_resize( - height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS -) -> Tuple[int, int]: - """ - Rescales the image so that the following conditions are met: + height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 +): + """Rescales the image so that the following conditions are met: 1. Both dimensions (height and width) are divisible by 'factor'. 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 3. The aspect ratio of the image is maintained as closely as possible. + """ - if max(height, width) / min(height, width) > MAX_RATIO: + if height < factor or width < factor: + raise ValueError(f"height:{height} and width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: raise ValueError( - f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar @@ -751,6 +797,7 @@ def smart_nframes( def is_decord_available() -> bool: import importlib.util + return importlib.util.find_spec("decord") is not None @@ -758,15 +805,16 @@ def _read_video_decord( ele: dict, ) -> paddle.Tensor: import decord + video_path = ele["video"] st = time.time() vr = decord.VideoReader(video_path) - if 'video_start' in ele or 'video_end' in ele: + if "video_start" in ele or "video_end" in ele: raise NotImplementedError("not support start_pts and end_pts in decord for now.") total_frames, video_fps = len(vr), vr.get_avg_fps() logger.info(f"decord: {video_path=}, 
{total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) - idx = paddle.linspace(0, total_frames - 1, nframes).round().astype('int64') + idx = paddle.linspace(0, total_frames - 1, nframes).round().astype("int64") idx = paddle.clip(idx, 0, total_frames - 1).tolist() video = vr.get_batch(idx).asnumpy() video = paddle.to_tensor(video).transpose([0, 3, 1, 2]) # Convert to TCHW format @@ -793,65 +841,65 @@ def get_video_reader_backend() -> str: return video_reader_backend -def custom_resize(video, size, interpolation='bicubic', antialias=True): +def custom_resize(video, size, interpolation="bicubic", antialias=True): """ Custom resize function for PaddlePaddle to mimic PyTorch's functionality. - + Args: video (paddle.Tensor): Input video tensor of shape [T, C, H, W] size (list[int]): Target size [H, W] interpolation (str): Interpolation method, default is 'bicubic' antialias (bool): Whether to use anti-aliasing, default is True - + Returns: paddle.Tensor: Resized video tensor """ # 确保输入是4D张量 [T, C, H, W] if video.ndim != 4: raise ValueError(f"Expected 4D tensor, got {video.ndim}D tensor") - + # 转换为浮点类型 - video = video.astype('float32') - + video = video.astype("float32") + # 获取原始尺寸 T, C, H, W = video.shape - + # 设置插值模式 - if interpolation == 'bicubic': - mode = 'bicubic' - elif interpolation == 'bilinear': - mode = 'bilinear' - elif interpolation == 'nearest': - mode = 'nearest' + if interpolation == "bicubic": + mode = "bicubic" + elif interpolation == "bilinear": + mode = "bilinear" + elif interpolation == "nearest": + mode = "nearest" else: raise ValueError(f"Unsupported interpolation mode: {interpolation}") - + # 重塑张量以便于处理 video = video.reshape([-1, C, H, W]) - + # 执行resize操作 - if antialias and mode in ['bicubic', 'bilinear']: + if antialias and mode in ["bicubic", "bilinear"]: # PaddlePaddle目前没有直接支持antialias的选项,我们可以通过先下采样再上采样来模拟 if H > size[0] or W > size[1]: # 下采样 - scale_factor = min(size[0]/H, size[1]/W, 1) + scale_factor = min(size[0] / H, size[1] / W, 1) if scale_factor < 1: video = F.interpolate(video, scale_factor=scale_factor, mode=mode, align_corners=False) # 上采样到目标尺寸 video = F.interpolate(video, size=size, mode=mode, align_corners=False) else: video = F.interpolate(video, size=size, mode=mode, align_corners=False) - + # 恢复原始形状 video = video.reshape([T, C, size[0], size[1]]) - + return video def gaussian_kernel_1d(size, sigma): """生成1D高斯核""" x = np.arange(-(size // 2), size // 2 + 1) - kernel = np.exp(-x**2 / (2 * sigma**2)) + kernel = np.exp(-(x**2) / (2 * sigma**2)) return kernel / kernel.sum() @@ -881,11 +929,11 @@ def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> Union[paddle.Ten max_pixels=max_pixels, ) video = F.interpolate( - video.astype('float32'), - size=[resized_height, resized_width], - mode='bicubic', + video.astype("float32"), + size=[resized_height, resized_width], + mode="bicubic", align_corners=False, - data_format='NCHW' + data_format="NCHW", ) video = paddle.clip(video, 0, 255)
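
Reviewer note (illustrative, not part of the patch): the sketch below shows how the <tool_call> wire format introduced by QwenToolUtils is expected to round-trip. The helper names format_tool_call and extract_tool_calls are hypothetical; only the JSON layout and the extraction regex mirror the diff.

# Hypothetical, self-contained sketch of the Qwen2.5 tool-call round trip (mirrors QwenToolUtils).
import json
import re


def format_tool_call(name: str, arguments: dict) -> str:
    # One call per <tool_call> block with a JSON body, as in QwenToolUtils.function_formatter.
    body = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
    return "<tool_call>\n" + body + "\n</tool_call>"


def extract_tool_calls(content: str):
    # Same regex as QwenToolUtils.tool_extractor: lazily capture each <tool_call> body.
    regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
    matches = re.findall(regex, content)
    if not matches:
        return content  # not a tool call, return the text unchanged
    return [json.loads(m.strip()) for m in matches]


if __name__ == "__main__":
    text = format_tool_call("get_weather", {"city": "Beijing"})
    print(extract_tool_calls(text))  # [{'name': 'get_weather', 'arguments': {'city': 'Beijing'}}]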
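
Reviewer note (illustrative, not part of the patch): the revised smart_resize snaps both sides to multiples of factor=28 and then clamps the total pixel count into [min_pixels, max_pixels]. A standalone sketch with a worked example, assuming the defaults shown in the diff:

# Hypothetical standalone sketch of the resizing rule from the revised smart_resize.
import math


def smart_resize(height, width, factor=28, min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280):
    if height < factor or width < factor:
        raise ValueError(f"height:{height} and width:{width} must be larger than factor:{factor}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    # Snap both sides to the nearest multiple of factor.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too many pixels: shrink, rounding down so the result stays under the budget.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: grow, rounding up to reach the minimum budget.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


print(smart_resize(1080, 1920))  # (728, 1316): multiples of 28, 958048 pixels <= 1003520 budget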