@@ -99,3 +99,150 @@ def __call__(self, text: str) -> NDArray[np.float32]:
99
99
assert self ._model is not None
100
100
result : NDArray [np .float32 ] = self ._model .encode (text , convert_to_numpy = True )
101
101
return result
102
+
103
+
104
+ # ColPali model cache for ColPali embedding functions
105
+ _colpali_model_cache = {}
106
+
107
+
108
+ def get_colpali_model (model : str ) -> tuple [Any , Any , int ]:
109
+ """Get or load ColPali model and processor."""
110
+ global _colpali_model_cache
111
+ if model not in _colpali_model_cache :
112
+ try :
113
+ from colpali_engine .models import ColPali , ColPaliProcessor # type: ignore[import-untyped]
114
+ except ImportError as e :
115
+ raise ImportError (
116
+ "ColPali is not available. Make sure cocoindex is installed with ColPali support."
117
+ ) from e
118
+
119
+ model_instance = ColPali .from_pretrained (model )
120
+ processor_instance = ColPaliProcessor .from_pretrained (model )
121
+
122
+ # Try to get dimension from FastEmbed API first
123
+ output_dim = None
124
+ try :
125
+ from fastembed import LateInteractionMultimodalEmbedding
126
+
127
+ # Use the standard FastEmbed ColPali model for dimension detection
128
+ # All ColPali variants should have the same embedding dimension
129
+ standard_colpali_model = "Qdrant/colpali-v1.3-fp16"
130
+
131
+ # Try to find the model in FastEmbed's supported models
132
+ supported_models = (
133
+ LateInteractionMultimodalEmbedding .list_supported_models ()
134
+ )
135
+ for supported_model in supported_models :
136
+ if supported_model ["model" ] == standard_colpali_model :
137
+ output_dim = supported_model ["dim" ]
138
+ break
139
+
140
+ except Exception :
141
+ # FastEmbed API failed, will fall back to model config
142
+ pass
143
+
144
+ # Fallback to model config if FastEmbed API failed
145
+ if output_dim is None :
146
+ if hasattr (model_instance , "config" ):
147
+ # Try different config attributes that might contain the hidden dimension
148
+ if hasattr (model_instance .config , "hidden_size" ):
149
+ output_dim = model_instance .config .hidden_size
150
+ elif hasattr (model_instance .config , "text_config" ) and hasattr (
151
+ model_instance .config .text_config , "hidden_size"
152
+ ):
153
+ output_dim = model_instance .config .text_config .hidden_size
154
+ elif hasattr (model_instance .config , "vision_config" ) and hasattr (
155
+ model_instance .config .vision_config , "hidden_size"
156
+ ):
157
+ output_dim = model_instance .config .vision_config .hidden_size
158
+ else :
159
+ raise ValueError (
160
+ f"Could not find hidden_size in model config for { model } . Config attributes: { dir (model_instance .config )} "
161
+ )
162
+ else :
163
+ raise ValueError (
164
+ f"Model { model } has no config attribute. Model attributes: { dir (model_instance )} "
165
+ )
166
+
167
+ _colpali_model_cache [model ] = {
168
+ "model" : model_instance ,
169
+ "processor" : processor_instance ,
170
+ "dimension" : output_dim ,
171
+ }
172
+ return (
173
+ _colpali_model_cache [model ]["model" ],
174
+ _colpali_model_cache [model ]["processor" ],
175
+ _colpali_model_cache [model ]["dimension" ],
176
+ )
177
def get_colpali_dimension(model: str) -> int:
    """Return the embedding dimension produced by the given ColPali model."""
    dimension = get_colpali_model(model)[2]
    return int(dimension)
184
def colpali_embed_image(img_bytes: bytes, model: str) -> list[list[float]]:
    """Embed an image with a ColPali model, returning multi-vector output.

    Args:
        img_bytes: Raw encoded image bytes (any format PIL can decode).
        model: ColPali model name, resolved via ``get_colpali_model``.

    Returns:
        Patch-level embeddings: a list of ``hidden_dim``-length float lists,
        one per image patch.

    Raises:
        ImportError: If PIL or torch is not installed.
        ValueError: If the model output is not a 3D tensor.
    """
    try:
        from PIL import Image
        import torch
        import io
    except ImportError as e:
        raise ImportError(
            "Required dependencies (PIL, torch) are missing for ColPali image embedding."
        ) from e

    colpali_model, processor, _expected_dim = get_colpali_model(model)
    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    inputs = processor.process_images([pil_image])
    with torch.no_grad():
        embeddings = colpali_model(**inputs)

    # Multi-vector output must be [batch, patches, hidden_dim].
    if len(embeddings.shape) != 3:
        raise ValueError(
            f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}"
        )

    # Drop the batch dimension and convert all patches in a single pass:
    # one device transfer instead of one .cpu() round-trip per patch.
    result: list[list[float]] = embeddings[0].cpu().numpy().tolist()
    return result
218
def colpali_embed_query(query: str, model: str) -> list[list[float]]:
    """Embed a text query with a ColPali model, returning multi-vector output.

    Args:
        query: Query text to embed.
        model: ColPali model name, resolved via ``get_colpali_model``.

    Returns:
        Token-level embeddings: a list of ``hidden_dim``-length float lists,
        one per query token.

    Raises:
        ImportError: If torch or numpy is not installed.
        ValueError: If the model output is not a 3D tensor.
    """
    try:
        import torch
        import numpy as np  # noqa: F401 -- Tensor.numpy() below needs numpy installed
    except ImportError as e:
        raise ImportError(
            "Required dependencies (torch, numpy) are missing for ColPali query embedding."
        ) from e

    colpali_model, processor, _target_dimension = get_colpali_model(model)
    inputs = processor.process_queries([query])
    with torch.no_grad():
        embeddings = colpali_model(**inputs)

    # Multi-vector output must be [batch, tokens, hidden_dim].
    if len(embeddings.shape) != 3:
        raise ValueError(
            f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}"
        )

    # Drop the batch dimension and convert all token embeddings in a single
    # pass: one device transfer instead of one .cpu() round-trip per token.
    result: list[list[float]] = embeddings[0].cpu().numpy().tolist()
    return result
0 commit comments