"""All builtin functions."""

import dataclasses
+ import functools

from typing import Annotated, Any, Literal

import numpy as np
@@ -101,69 +102,51 @@ def __call__(self, text: str) -> NDArray[np.float32]:
        return result


- # Global ColPali model cache to avoid reloading models
- _COLPALI_MODEL_CACHE = {}
-
-
- class _ColPaliModelManager:
-     """Shared model manager for ColPali models to avoid duplicate loading."""
-
-     @staticmethod
-     def get_model_and_processor(model_name: str) -> dict[str, Any]:
-         """Get or load ColPali model and processor, with caching."""
-         if model_name not in _COLPALI_MODEL_CACHE:
-             try:
-                 from colpali_engine.models import ColPali, ColPaliProcessor  # type: ignore[import-untyped]
-             except ImportError as e:
-                 raise ImportError(
-                     "ColPali is not available. Make sure cocoindex is installed with ColPali support."
-                 ) from e
-
-             model = ColPali.from_pretrained(model_name)
-             processor = ColPaliProcessor.from_pretrained(model_name)
-
-             # Get dimension from FastEmbed API
-             dimension = _ColPaliModelManager._detect_dimension()
-
-             _COLPALI_MODEL_CACHE[model_name] = {
-                 "model": model,
-                 "processor": processor,
-                 "dimension": dimension,
-             }
-
-         return _COLPALI_MODEL_CACHE[model_name]
-
-     @staticmethod
-     def _detect_dimension() -> int:
-         """Detect ColPali embedding dimension using FastEmbed API."""
-         try:
-             from fastembed import LateInteractionMultimodalEmbedding
-
-             # Use the standard FastEmbed ColPali model for dimension detection
-             standard_colpali_model = "Qdrant/colpali-v1.3-fp16"
-
-             supported_models = (
-                 LateInteractionMultimodalEmbedding.list_supported_models()
-             )
-             for supported_model in supported_models:
-                 if supported_model["model"] == standard_colpali_model:
-                     dim = supported_model["dim"]
-                     if isinstance(dim, int):
-                         return dim
-                     else:
-                         raise ValueError(
-                             f"Expected integer dimension, got {type(dim)}: {dim}"
-                         )
-
-             raise ValueError(
-                 f"Could not find dimension for ColPali model in FastEmbed supported models"
-             )
-
-         except ImportError:
-             raise ImportError(
-                 "FastEmbed is required for ColPali dimension detection. "
-                 "Install it with: pip install fastembed"
-             )
+ @functools.cache
+ def _get_colpali_model_and_processor(model_name: str) -> dict[str, Any]:
+     """Get or load ColPali model and processor, with caching."""
+     try:
+         from colpali_engine.models import ColPali, ColPaliProcessor  # type: ignore[import-untyped]
+     except ImportError as e:
+         raise ImportError(
+             "ColPali is not available. Make sure cocoindex is installed with ColPali support."
+         ) from e
+
+     model = ColPali.from_pretrained(model_name)
+     processor = ColPaliProcessor.from_pretrained(model_name)
+
+     # Get dimension from the actual model
+     dimension = _detect_colpali_dimension(model, processor)
+
+     return {
+         "model": model,
+         "processor": processor,
+         "dimension": dimension,
+     }
+
+
+ def _detect_colpali_dimension(model: Any, processor: Any) -> int:
+     """Detect ColPali embedding dimension from the actual model config."""
+     # Try to access embedding dimension
+     if hasattr(model.config, "embedding_dim"):
+         dim = model.config.embedding_dim
+     else:
+         # Fallback: infer from output shape with dummy data
+         from PIL import Image
+         import numpy as np
+         import torch
+
+         dummy_img = Image.fromarray(np.zeros((224, 224, 3), np.uint8))
+         # Use the processor to process the dummy image
+         processed = processor.process_images([dummy_img])
+         with torch.no_grad():
+             output = model(**processed)
+         dim = int(output.shape[-1])
+         if isinstance(dim, int):
+             return dim
+         else:
+             raise ValueError(f"Expected integer dimension, got {type(dim)}: {dim}")
+     return dim


class ColPaliEmbedImage(op.FunctionSpec):
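
This hunk swaps the hand-rolled `_COLPALI_MODEL_CACHE` dict plus `_ColPaliModelManager` wrapper for a single `functools.cache`-decorated loader, and derives the embedding dimension from the loaded model itself instead of FastEmbed's registry of supported models. Below is a minimal sketch of the caching behavior only; the stub loader body, the model name, and the 128 dimension are placeholders, not the real `colpali_engine` calls.

import functools
from typing import Any

load_calls = 0  # counts how often the expensive load body actually runs


@functools.cache
def get_model_and_processor(model_name: str) -> dict[str, Any]:
    """Runs once per distinct model_name per process; later calls return the cached dict."""
    global load_calls
    load_calls += 1
    # Placeholders for ColPali.from_pretrained(model_name) / ColPaliProcessor.from_pretrained(model_name).
    return {"model": object(), "processor": object(), "dimension": 128}


first = get_model_and_processor("some/colpali-checkpoint")
second = get_model_and_processor("some/colpali-checkpoint")
assert first is second   # every caller sees the same dict object
assert load_calls == 1   # the heavy load happened exactly once

Like the old module-level dict, the memoized result lives for the lifetime of the process; the only key is the (hashable) `model_name` argument.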
@@ -198,9 +181,7 @@ class ColPaliEmbedImageExecutor:
    def analyze(self, _img_bytes: Any) -> type:
        # Get shared model and dimension
        if self._cached_model_data is None:
-             self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                 self.spec.model
-             )
+             self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        # Return multi-vector type: Variable patches x Fixed hidden dimension
        dimension = self._cached_model_data["dimension"]
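
This hunk and the three below it are the same mechanical substitution at each call site: the executors keep their lazy `self._cached_model_data` check and simply call the memoized module-level loader instead of `_ColPaliModelManager.get_model_and_processor`. A small sketch of how that two-level pattern behaves, using a hypothetical `Spec`/`Executor` pair rather than the real `op.FunctionSpec` machinery:

import functools
from dataclasses import dataclass
from typing import Any


@functools.cache
def get_model_and_processor(model_name: str) -> dict[str, Any]:
    # Placeholder load; the real loader builds ColPali / ColPaliProcessor and detects the dimension.
    return {"model": object(), "processor": object(), "dimension": 128}


@dataclass
class Spec:
    model: str


class Executor:
    """Hypothetical executor mirroring the lazy-caching shape used in these hunks."""

    def __init__(self, spec: Spec) -> None:
        self.spec = spec
        self._cached_model_data: dict[str, Any] | None = None

    def analyze(self) -> int:
        # Populate the per-instance cache on first use, from the process-wide memoized loader.
        if self._cached_model_data is None:
            self._cached_model_data = get_model_and_processor(self.spec.model)
        return self._cached_model_data["dimension"]


a, b = Executor(Spec("some/colpali-checkpoint")), Executor(Spec("some/colpali-checkpoint"))
assert a.analyze() == b.analyze() == 128
assert a._cached_model_data is b._cached_model_data  # both instances share one loaded bundle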
@@ -218,9 +199,7 @@ def __call__(self, img_bytes: bytes) -> list[list[float]]:

        # Get shared model and processor
        if self._cached_model_data is None:
-             self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                 self.spec.model
-             )
+             self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        model = self._cached_model_data["model"]
        processor = self._cached_model_data["processor"]
@@ -279,9 +258,7 @@ class ColPaliEmbedQueryExecutor:
    def analyze(self, _query: Any) -> type:
        # Get shared model and dimension
        if self._cached_model_data is None:
-             self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                 self.spec.model
-             )
+             self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        # Return multi-vector type: Variable tokens x Fixed hidden dimension
        dimension = self._cached_model_data["dimension"]
@@ -297,9 +274,7 @@ def __call__(self, query: str) -> list[list[float]]:

        # Get shared model and processor
        if self._cached_model_data is None:
-             self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                 self.spec.model
-             )
+             self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        model = self._cached_model_data["model"]
        processor = self._cached_model_data["processor"]
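
Separately, the new `_detect_colpali_dimension` fallback infers the hidden size by pushing one dummy image through the model and reading the last axis of the output, rather than consulting FastEmbed's model list. The sketch below reproduces that idea with numpy-only stand-ins (no torch, PIL, or real ColPali weights); `StubProcessor`, `StubModel`, and the 128-dim output are illustrative assumptions.

import numpy as np


class StubProcessor:
    """Stand-in for ColPaliProcessor: batches raw images into model inputs."""

    def process_images(self, images: list) -> dict:
        return {"pixel_values": np.stack([np.asarray(img, dtype=np.float32) for img in images])}


class StubModel:
    """Stand-in for ColPali: emits (batch, n_patches, hidden_dim) embeddings."""

    hidden_dim = 128  # placeholder hidden size

    def __call__(self, **inputs):
        batch = inputs["pixel_values"].shape[0]
        return np.zeros((batch, 1030, self.hidden_dim), dtype=np.float32)


def detect_dimension(model, processor) -> int:
    """Same fallback idea: run a dummy image through the model, read output.shape[-1]."""
    dummy_img = np.zeros((224, 224, 3), np.uint8)
    processed = processor.process_images([dummy_img])
    output = model(**processed)
    return int(output.shape[-1])


assert detect_dimension(StubModel(), StubProcessor()) == 128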