
Add fast image processor Janus, Deepseek VL, Deepseek VL hybrid #39739

Merged
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/deepseek_vl.md
@@ -209,6 +209,10 @@ model = DeepseekVLForConditionalGeneration.from_pretrained(

[[autodoc]] DeepseekVLImageProcessor

## DeepseekVLImageProcessorFast

[[autodoc]] DeepseekVLImageProcessorFast

## DeepseekVLModel

[[autodoc]] DeepseekVLModel
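As context for the new `DeepseekVLImageProcessorFast` autodoc entry, here is a minimal usage sketch. The checkpoint name is an assumption (substitute whichever DeepseekVL checkpoint the surrounding doc uses); the 384x384 output shape follows from the class defaults introduced later in this PR.

```python
# Minimal sketch (not part of the PR): load the fast processor and run it on one image.
import requests
from PIL import Image

from transformers import DeepseekVLImageProcessorFast

# assumed checkpoint name; replace with the one referenced in the model doc
processor = DeepseekVLImageProcessorFast.from_pretrained("deepseek-community/deepseek-vl-1.3b-chat")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 384, 384]) with the default size
```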
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/deepseek_vl_hybrid.md
@@ -208,6 +208,10 @@ model = DeepseekVLHybridForConditionalGeneration.from_pretrained(

[[autodoc]] DeepseekVLHybridImageProcessor

## DeepseekVLHybridImageProcessorFast

[[autodoc]] DeepseekVLHybridImageProcessorFast

## DeepseekVLHybridModel

[[autodoc]] DeepseekVLHybridModel
14 changes: 9 additions & 5 deletions docs/source/en/model_doc/janus.md
@@ -44,11 +44,11 @@ Here is the example of visual understanding with a single image.
> Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
```python
import torch
from PIL import Image
import requests

from transformers import JanusForConditionalGeneration, JanusProcessor

model_id = "deepseek-community/Janus-Pro-1B"
# Prepare Input for generation.
@@ -64,7 +64,7 @@ messages = [

# Set generation mode to `text` to perform text generation.
processor = JanusProcessor.from_pretrained(model_id)
model = JanusForConditionalGeneration.from_pretrained(model_id,
torch_dtype=torch.bfloat16,
device_map="auto")

@@ -209,6 +209,10 @@ for i, image in enumerate(images['pixel_values']):

[[autodoc]] JanusImageProcessor

## JanusImageProcessorFast

[[autodoc]] JanusImageProcessorFast

## JanusVisionModel

[[autodoc]] JanusVisionModel
6 changes: 3 additions & 3 deletions src/transformers/models/auto/image_processing_auto.py
@@ -77,8 +77,8 @@
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
("deepseek_vl", ("DeepseekVLImageProcessor")),
("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor")),
("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
@@ -112,7 +112,7 @@
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
("janus", ("JanusImageProcessor")),
("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
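Registering the fast classes in the auto mapping is what lets `AutoImageProcessor` hand them out. A small sketch of the expected behavior, reusing the Janus checkpoint from the docs above:

```python
# Sketch: with the mapping updated, AutoImageProcessor can resolve the fast variant.
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("deepseek-community/Janus-Pro-1B", use_fast=True)
print(type(processor).__name__)  # expected: JanusImageProcessorFast
```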
@@ -20,7 +20,9 @@


from ...configuration_utils import PretrainedConfig
-from ...utils import logging
+from ...utils import (
+    logging,
+)
from ..auto import CONFIG_MAPPING, AutoConfig


@@ -131,6 +131,7 @@ def resize(
self,
image: np.ndarray,
size: Union[dict[str, int], int],
background_color: Optional[tuple[int, int, int]] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -142,6 +143,10 @@
Args:
image (`np.ndarray`):
Image to resize.
size (`dict[str, int]` or `int`):
The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
background_color (`tuple[int, int, int]`, *optional*):
The background color to use for the padding. Defaults to `self.background_color` when not provided.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
@@ -160,6 +165,7 @@
Returns:
`np.ndarray`: The resized image.
"""
background_color = background_color if background_color is not None else self.background_color
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)

@@ -191,7 +197,7 @@ def resize(
# Expand and pad the images to obtain a square image of dimensions `size x size`
image = self.pad_to_square(
image=image,
-background_color=self.background_color,
+background_color=background_color,
input_data_format=input_data_format,
)
return image
@@ -406,9 +412,5 @@ def pad_to_square(

return result

def postprocess(self):
"""Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing."""
raise AttributeError("Not needed for DeepseekVL")


__all__ = ["DeepseekVLImageProcessor"]
@@ -0,0 +1,199 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_deepseek_vl.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch.nn.functional as F

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
)


if is_torch_available():
import torch


class DeepseekVLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
r"""
min_size (`int`, *optional*, defaults to 14):
The minimum allowed size for the resized image. Ensures that neither the height nor width
falls below this value after resizing.
"""

min_size: int


@auto_docstring
class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"height": 384, "width": 384}
min_size = 14
do_resize = True
do_rescale = True
do_normalize = True
valid_kwargs = DeepseekVLFastImageProcessorKwargs

def __init__(self, **kwargs: Unpack[DeepseekVLFastImageProcessorKwargs]):
super().__init__(**kwargs)
if kwargs.get("image_mean", None) is None:
background_color = (127, 127, 127)
else:
background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
self.background_color = tuple(background_color)

def resize(
self,
image: "torch.Tensor",
size: SizeDict,
min_size: int,
interpolation: "F.InterpolationMode" = None,
antialias: bool = True,
**kwargs,
) -> "torch.Tensor":
if size.height is None or size.width is None or size.height != size.width:
raise ValueError(
f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
)
size = size.height

height, width = image.shape[-2:]
max_size = max(height, width)

delta = size / max_size
# Largest side becomes `size` and the other side is scaled according to the aspect ratio.
output_size_nonpadded = SizeDict(
height=max(int(height * delta), min_size),
width=max(int(width * delta), min_size),
)

return super().resize(image, size=output_size_nonpadded, interpolation=interpolation, antialias=antialias)

def pad_to_square(
self,
images: "torch.Tensor",
background_color: Union[int, tuple[int, int, int]] = 0,
) -> "torch.Tensor":
"""
Pads an image to a square based on the longest edge.

Args:
images (`torch.Tensor`):
The images to pad.
background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for single channel or a
tuple of integers for multi-channel images. If passed as an integer in
multi-channel mode, it will default to `0` in subsequent channels.

Returns:
`torch.Tensor`: The padded images.
"""
height, width = images.shape[-2:]
num_channels = images.shape[1]
batch_size = images.shape[0]

if height == width:
return images

max_dim = max(height, width)

# Ensure background_color is the correct shape
if isinstance(background_color, int):
background_color = [background_color]
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)

padded_images = torch.zeros(
(batch_size, num_channels, max_dim, max_dim), dtype=images.dtype, device=images.device
)
for i, color in enumerate(background_color):
padded_images[:, i, :, :] = color
if width > height:
start = (max_dim - height) // 2
padded_images[:, :, start : start + height, :] = images
else:
start = (max_dim - width) // 2
padded_images[:, :, :, start : start + width] = images

return padded_images

def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
min_size: int,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
do_pad: bool = True,
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(
image=stacked_images, size=size, min_size=min_size, interpolation=interpolation
)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)

# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_pad:
stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)


__all__ = ["DeepseekVLImageProcessorFast"]
13 changes: 13 additions & 0 deletions src/transformers/models/deepseek_vl/modular_deepseek_vl.py
@@ -33,6 +33,7 @@
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
from ..janus.image_processing_janus import JanusImageProcessor
from ..janus.image_processing_janus_fast import JanusImageProcessorFast
from ..janus.modeling_janus import JanusForConditionalGeneration, JanusModel, JanusPreTrainedModel


@@ -181,13 +182,24 @@ def generate(self):


class DeepseekVLImageProcessor(JanusImageProcessor):
def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def postprocess(self):
raise AttributeError("Not needed for DeepseekVL")

def unnormalize(self):
raise AttributeError("Not needed for DeepseekVL")


class DeepseekVLImageProcessorFast(JanusImageProcessorFast):
def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def postprocess(self):
raise AttributeError("Not needed for DeepseekVL")


class DeepseekVLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {"padding": False},
@@ -322,5 +334,6 @@ def model_input_names(self):
"DeepseekVLModel",
"DeepseekVLForConditionalGeneration",
"DeepseekVLImageProcessor",
"DeepseekVLImageProcessorFast",
"DeepseekVLProcessor",
]
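One note on the `postprocess`/`unnormalize` overrides added to the modular file: raising `AttributeError` in a modular override is the convention for removing the inherited method from the generated class, which is why the generated fast file above contains no `postprocess` at all. A tiny sketch of the visible effect (the Janus-side check is an assumption, since those classes sit outside this diff):

```python
# Sketch: the stubs strip the Janus postprocess helper from the DeepseekVL processors.
from transformers import DeepseekVLImageProcessorFast, JanusImageProcessorFast

print(hasattr(JanusImageProcessorFast, "postprocess"))       # assumed True on the Janus parent
print(hasattr(DeepseekVLImageProcessorFast, "postprocess"))  # False: removed during modular conversion
```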
1 change: 1 addition & 0 deletions src/transformers/models/deepseek_vl_hybrid/__init__.py
@@ -21,6 +21,7 @@
from .configuration_deepseek_vl_hybrid import *
from .image_processing_deepseek_vl_fast_hybrid import *
from .image_processing_deepseek_vl_hybrid import *
from .image_processing_deepseek_vl_hybrid_fast import *
from .modeling_deepseek_vl_hybrid import *
from .processing_deepseek_vl_hybrid import *
else: