Commit 7b4d984

Add fast image processor Janus, Deepseek VL, Deepseek VL hybrid (#39739)
* add fast image processor Janus, deepseek_vl, deepseek_vl_hybrid
* fix after review
1 parent 88ead3f commit 7b4d984

19 files changed: +1273 −102 lines
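In practical terms, each of the three models gains a torch-backed fast image processor that `AutoImageProcessor` can now resolve. A minimal usage sketch, assuming the `deepseek-community/Janus-Pro-1B` checkpoint referenced in the Janus docs below; the printed shape is illustrative:

```python
from PIL import Image

from transformers import AutoImageProcessor

# With use_fast=True, the Auto class resolves to the new fast processor class.
processor = AutoImageProcessor.from_pretrained("deepseek-community/Janus-Pro-1B", use_fast=True)

image = Image.new("RGB", (640, 480))  # stand-in for a real input image
inputs = processor(images=image, return_tensors="pt")

print(type(processor).__name__)      # JanusImageProcessorFast
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 384, 384])
```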

docs/source/en/model_doc/deepseek_vl.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -209,6 +209,10 @@ model = DeepseekVLForConditionalGeneration.from_pretrained(
 
 [[autodoc]] DeepseekVLImageProcessor
 
+## DeepseekVLImageProcessorFast
+
+[[autodoc]] DeepseekVLImageProcessorFast
+
 ## DeepseekVLModel
 
 [[autodoc]] DeepseekVLModel
```

docs/source/en/model_doc/deepseek_vl_hybrid.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -208,6 +208,10 @@ model = DeepseekVLHybridForConditionalGeneration.from_pretrained(
 
 [[autodoc]] DeepseekVLHybridImageProcessor
 
+## DeepseekVLHybridImageProcessorFast
+
+[[autodoc]] DeepseekVLHybridImageProcessorFast
+
 ## DeepseekVLHybridModel
 
 [[autodoc]] DeepseekVLHybridModel
```

docs/source/en/model_doc/janus.md

Lines changed: 9 additions & 5 deletions
````diff
@@ -44,11 +44,11 @@ Here is the example of visual understanding with a single image.
 > Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
 
 ```python
-import torch
-from PIL import Image
-import requests
+import torch
+from PIL import Image
+import requests
 
-from transformers import JanusForConditionalGeneration, JanusProcessor
+from transformers import JanusForConditionalGeneration, JanusProcessor
 
 model_id = "deepseek-community/Janus-Pro-1B"
 # Prepare Input for generation.
@@ -64,7 +64,7 @@ messages = [
 
 # Set generation mode to `text` to perform text generation.
 processor = JanusProcessor.from_pretrained(model_id)
-model = JanusForConditionalGeneration.from_pretrained(model_id,
+model = JanusForConditionalGeneration.from_pretrained(model_id,
                                                       torch_dtype=torch.bfloat16,
                                                       device_map="auto")
 
@@ -209,6 +209,10 @@ for i, image in enumerate(images['pixel_values']):
 
 [[autodoc]] JanusImageProcessor
 
+## JanusImageProcessorFast
+
+[[autodoc]] JanusImageProcessorFast
+
 ## JanusVisionModel
 
 [[autodoc]] JanusVisionModel
````
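Since the docs now expose `JanusImageProcessorFast`, here is a quick direct-use sketch, assuming the class defaults (384×384 output, CLIP mean/std) and that the class is exported at the top level:

```python
import torch

from transformers import JanusImageProcessorFast

processor = JanusImageProcessorFast()

# Fast processors accept torch tensors (C, H, W) as well as PIL images.
image = torch.randint(0, 256, (3, 200, 300), dtype=torch.uint8)
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # torch.Size([1, 3, 384, 384])
```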

src/transformers/models/auto/image_processing_auto.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -78,8 +78,8 @@
         ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
         ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
         ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
-        ("deepseek_vl", ("DeepseekVLImageProcessor")),
-        ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor")),
+        ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
+        ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
         ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
         ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
         ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
@@ -113,7 +113,7 @@
         ("imagegpt", ("ImageGPTImageProcessor",)),
         ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
         ("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
-        ("janus", ("JanusImageProcessor")),
+        ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
         ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
         ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
```

src/transformers/models/deepseek_vl/configuration_deepseek_vl.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -20,7 +20,9 @@
 
 
 from ...configuration_utils import PretrainedConfig
-from ...utils import logging
+from ...utils import (
+    logging,
+)
 from ..auto import CONFIG_MAPPING, AutoConfig
 
 
```
src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -131,6 +131,7 @@ def resize(
         self,
         image: np.ndarray,
         size: Union[dict[str, int], int],
+        background_color: Optional[tuple[int, int, int]] = None,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -142,6 +143,10 @@ def resize(
         Args:
             image (`np.ndarray`):
                 Image to resize.
+            size (`dict[str, int]` or `int`):
+                The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
+            background_color (`tuple[int, int, int]`, *optional*):
+                The background color to use for the padding.
             resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                 `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
             data_format (`ChannelDimension` or `str`, *optional*):
@@ -160,6 +165,7 @@ def resize(
         Returns:
             `np.ndarray`: The resized image.
         """
+        background_color = background_color if background_color is not None else self.background_color
         if input_data_format is None:
             input_data_format = infer_channel_dimension_format(image)
 
@@ -191,7 +197,7 @@ def resize(
         # Expand and pad the images to obtain a square image of dimensions `size x size`
         image = self.pad_to_square(
             image=image,
-            background_color=self.background_color,
+            background_color=background_color,
             input_data_format=input_data_format,
         )
         return image
@@ -406,9 +412,5 @@ def pad_to_square(
 
         return result
 
-    def postprocess(self):
-        """Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing."""
-        raise AttributeError("Not needed for DeepseekVL")
-
 
 __all__ = ["DeepseekVLImageProcessor"]
```
src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py

Lines changed: 199 additions & 0 deletions
```diff
@@ -0,0 +1,199 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch.nn.functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+    BaseImageProcessorFast,
+    DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
+)
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling, SizeDict
+from ...processing_utils import Unpack
+from ...utils import (
+    TensorType,
+    auto_docstring,
+    is_torch_available,
+)
+
+
+if is_torch_available():
+    import torch
+
+
+class DeepseekVLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    r"""
+    min_size (`int`, *optional*, defaults to 14):
+        The minimum allowed size for the resized image. Ensures that neither the height nor width
+        falls below this value after resizing.
+    """
+
+    min_size: int
+
+
+@auto_docstring
+class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BICUBIC
+    image_mean = OPENAI_CLIP_MEAN
+    image_std = OPENAI_CLIP_STD
+    size = {"height": 384, "width": 384}
+    min_size = 14
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+    valid_kwargs = DeepseekVLFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[DeepseekVLFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+        if kwargs.get("image_mean", None) is None:
+            background_color = (127, 127, 127)
+        else:
+            background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
+        self.background_color = tuple(background_color)
+
+    def resize(
+        self,
+        image: "torch.Tensor",
+        size: SizeDict,
+        min_size: int,
+        interpolation: "F.InterpolationMode" = None,
+        antialias: bool = True,
+        **kwargs,
+    ) -> "torch.Tensor":
+        if size.height is None or size.width is None or size.height != size.width:
+            raise ValueError(
+                f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+            )
+        size = size.height
+
+        height, width = image.shape[-2:]
+        max_size = max(height, width)
+
+        delta = size / max_size
+        # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+        output_size_nonpadded = SizeDict(
+            height=max(int(height * delta), min_size),
+            width=max(int(width * delta), min_size),
+        )
+
+        return super().resize(image, size=output_size_nonpadded, interpolation=interpolation, antialias=antialias)
+
+    def pad_to_square(
+        self,
+        images: "torch.Tensor",
+        background_color: Union[int, tuple[int, int, int]] = 0,
+    ) -> "torch.Tensor":
+        """
+        Pads an image to a square based on the longest edge.
+
+        Args:
+            images (`torch.Tensor`):
+                The images to pad.
+            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single channel or a
+                tuple of integers for multi-channel images. If passed as an integer
+                in multi-channel mode, it will default to `0` in subsequent channels.
+
+        Returns:
+            `torch.Tensor`: The padded images.
+        """
+        height, width = images.shape[-2:]
+        num_channels = images.shape[1]
+        batch_size = images.shape[0]
+
+        if height == width:
+            return images
+
+        max_dim = max(height, width)
+
+        # Ensure background_color is the correct shape
+        if isinstance(background_color, int):
+            background_color = [background_color]
+        elif len(background_color) != num_channels:
+            raise ValueError(
+                f"background_color must have exactly {num_channels} elements to match the number of channels"
+            )
+
+        padded_images = torch.zeros(
+            (batch_size, num_channels, max_dim, max_dim), dtype=images.dtype, device=images.device
+        )
+        for i, color in enumerate(background_color):
+            padded_images[:, i, :, :] = color
+        if width > height:
+            start = (max_dim - height) // 2
+            padded_images[:, :, start : start + height, :] = images
+        else:
+            start = (max_dim - width) // 2
+            padded_images[:, :, :, start : start + width] = images
+
+        return padded_images
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        min_size: int,
+        interpolation: Optional["F.InterpolationMode"],
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
+        do_pad: bool = True,
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_images, size=size, min_size=min_size, interpolation=interpolation
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_pad:
+                stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["DeepseekVLImageProcessorFast"]
```

src/transformers/models/deepseek_vl/modular_deepseek_vl.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@
 from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
 from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
 from ..janus.image_processing_janus import JanusImageProcessor
+from ..janus.image_processing_janus_fast import JanusImageProcessorFast
 from ..janus.modeling_janus import JanusForConditionalGeneration, JanusModel, JanusPreTrainedModel
 
 
@@ -181,13 +182,24 @@ def generate(self):
 
 
 class DeepseekVLImageProcessor(JanusImageProcessor):
+    def __init__(self, **super_kwargs):
+        super().__init__(**super_kwargs)
+
     def postprocess(self):
         raise AttributeError("Not needed for DeepseekVL")
 
     def unnormalize(self):
         raise AttributeError("Not needed for DeepseekVL")
 
 
+class DeepseekVLImageProcessorFast(JanusImageProcessorFast):
+    def __init__(self, **super_kwargs):
+        super().__init__(**super_kwargs)
+
+    def postprocess(self):
+        raise AttributeError("Not needed for DeepseekVL")
+
+
 class DeepseekVLProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
         "text_kwargs": {"padding": False},
@@ -322,5 +334,6 @@ def model_input_names(self):
     "DeepseekVLModel",
     "DeepseekVLForConditionalGeneration",
     "DeepseekVLImageProcessor",
+    "DeepseekVLImageProcessorFast",
     "DeepseekVLProcessor",
 ]
```
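The modular file is the source of truth here: the generated `image_processing_deepseek_vl_fast.py` above is regenerated from these subclass definitions. A quick behavioral sketch, assuming default construction works; whether `postprocess` is an explicit override (as in the modular file) or simply absent from the generated class, the call surfaces an `AttributeError` either way:

```python
from transformers import DeepseekVLImageProcessorFast

processor = DeepseekVLImageProcessorFast()

# `postprocess` is deliberately not supported for DeepseekVL.
try:
    processor.postprocess()
except AttributeError as err:
    print(err)
```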

src/transformers/models/deepseek_vl_hybrid/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@
     from .configuration_deepseek_vl_hybrid import *
     from .image_processing_deepseek_vl_fast_hybrid import *
     from .image_processing_deepseek_vl_hybrid import *
+    from .image_processing_deepseek_vl_hybrid_fast import *
     from .modeling_deepseek_vl_hybrid import *
     from .processing_deepseek_vl_hybrid import *
 else:
```
