Skip to content

Commit 6355024

Browse files
badmonster0 authored and georgeh0 committed
Optimize ColPali with functools.cache and add colpali feature
- Use @functools.cache for model caching instead of manual dict - Add 'colpali' optional dependency separate from 'embeddings' - Fix dimension detection and LAN access for frontend
1 parent 8f45c79 commit 6355024

File tree

4 files changed

+59
-80
lines changed

4 files changed

+59
-80
lines changed

examples/image_search/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ export OLLAMA_MODEL="gemma3" # Optional, for caption generation
6262
- Install dependencies:
6363
```
6464
pip install -e .
65-
pip install 'cocoindex[embeddings]' # Adds ColPali and sentence-transformers support
65+
pip install 'cocoindex[colpali]' # Adds ColPali support
6666
```
6767

6868
- Configure model (optional):

examples/image_search/frontend/src/App.jsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import React, { useState } from 'react';
22

3-
const API_URL = 'http://localhost:8000/search'; // Adjust this to your backend search endpoint
3+
const API_URL = `http://${window.location.hostname}:8000/search`;
44

55
export default function App() {
66
const [query, setQuery] = useState('');
@@ -42,7 +42,7 @@ export default function App() {
4242
{results.length === 0 && !loading && <div>No results</div>}
4343
{results.map((result, idx) => (
4444
<div key={idx} className="result-card">
45-
<img src={`http://localhost:8000/img/${result.filename}`} alt={result.filename} className="result-img" />
45+
<img src={`http://${window.location.hostname}:8000/img/${result.filename}`} alt={result.filename} className="result-img" />
4646
<div className="score">Score: {result.score?.toFixed(3)}</div>
4747
</div>
4848
))}

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@ features = ["pyo3/extension-module"]
3131
[project.optional-dependencies]
3232
dev = ["pytest", "pytest-asyncio", "ruff", "mypy", "pre-commit"]
3333

34-
embeddings = ["sentence-transformers>=3.3.1", "colpali-engine", "fastembed"]
35-
all = ["sentence-transformers>=3.3.1", "colpali-engine", "fastembed"]
34+
embeddings = ["sentence-transformers>=3.3.1"]
35+
colpali = ["colpali-engine"]
36+
37+
# We need to repeat the dependency above to make it available for the `all` feature.
38+
# Indirect dependencies such as "cocoindex[embeddings]" will not work for local development.
39+
all = ["sentence-transformers>=3.3.1", "colpali-engine"]
3640

3741
[tool.mypy]
3842
python_version = "3.11"

python/cocoindex/functions.py

Lines changed: 50 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""All builtin functions."""
22

33
import dataclasses
4+
import functools
45
from typing import Annotated, Any, Literal
56

67
import numpy as np
@@ -101,69 +102,51 @@ def __call__(self, text: str) -> NDArray[np.float32]:
101102
return result
102103

103104

104-
# Global ColPali model cache to avoid reloading models
105-
_COLPALI_MODEL_CACHE = {}
106-
107-
108-
class _ColPaliModelManager:
109-
"""Shared model manager for ColPali models to avoid duplicate loading."""
110-
111-
@staticmethod
112-
def get_model_and_processor(model_name: str) -> dict[str, Any]:
113-
"""Get or load ColPali model and processor, with caching."""
114-
if model_name not in _COLPALI_MODEL_CACHE:
115-
try:
116-
from colpali_engine.models import ColPali, ColPaliProcessor # type: ignore[import-untyped]
117-
except ImportError as e:
118-
raise ImportError(
119-
"ColPali is not available. Make sure cocoindex is installed with ColPali support."
120-
) from e
121-
122-
model = ColPali.from_pretrained(model_name)
123-
processor = ColPaliProcessor.from_pretrained(model_name)
124-
125-
# Get dimension from FastEmbed API
126-
dimension = _ColPaliModelManager._detect_dimension()
127-
128-
_COLPALI_MODEL_CACHE[model_name] = {
129-
"model": model,
130-
"processor": processor,
131-
"dimension": dimension,
132-
}
133-
134-
return _COLPALI_MODEL_CACHE[model_name]
135-
136-
@staticmethod
137-
def _detect_dimension() -> int:
138-
"""Detect ColPali embedding dimension using FastEmbed API."""
139-
try:
140-
from fastembed import LateInteractionMultimodalEmbedding
141-
142-
# Use the standard FastEmbed ColPali model for dimension detection
143-
standard_colpali_model = "Qdrant/colpali-v1.3-fp16"
144-
145-
supported_models = (
146-
LateInteractionMultimodalEmbedding.list_supported_models()
147-
)
148-
for supported_model in supported_models:
149-
if supported_model["model"] == standard_colpali_model:
150-
dim = supported_model["dim"]
151-
if isinstance(dim, int):
152-
return dim
153-
else:
154-
raise ValueError(
155-
f"Expected integer dimension, got {type(dim)}: {dim}"
156-
)
157-
158-
raise ValueError(
159-
f"Could not find dimension for ColPali model in FastEmbed supported models"
160-
)
161-
162-
except ImportError:
163-
raise ImportError(
164-
"FastEmbed is required for ColPali dimension detection. "
165-
"Install it with: pip install fastembed"
166-
)
105+
@functools.cache
106+
def _get_colpali_model_and_processor(model_name: str) -> dict[str, Any]:
107+
"""Get or load ColPali model and processor, with caching."""
108+
try:
109+
from colpali_engine.models import ColPali, ColPaliProcessor # type: ignore[import-untyped]
110+
except ImportError as e:
111+
raise ImportError(
112+
"ColPali is not available. Make sure cocoindex is installed with ColPali support."
113+
) from e
114+
115+
model = ColPali.from_pretrained(model_name)
116+
processor = ColPaliProcessor.from_pretrained(model_name)
117+
118+
# Get dimension from the actual model
119+
dimension = _detect_colpali_dimension(model, processor)
120+
121+
return {
122+
"model": model,
123+
"processor": processor,
124+
"dimension": dimension,
125+
}
126+
127+
128+
def _detect_colpali_dimension(model: Any, processor: Any) -> int:
129+
"""Detect ColPali embedding dimension from the actual model config."""
130+
# Try to access embedding dimension
131+
if hasattr(model.config, "embedding_dim"):
132+
dim = model.config.embedding_dim
133+
else:
134+
# Fallback: infer from output shape with dummy data
135+
from PIL import Image
136+
import numpy as np
137+
import torch
138+
139+
dummy_img = Image.fromarray(np.zeros((224, 224, 3), np.uint8))
140+
# Use the processor to process the dummy image
141+
processed = processor.process_images([dummy_img])
142+
with torch.no_grad():
143+
output = model(**processed)
144+
dim = int(output.shape[-1])
145+
if isinstance(dim, int):
146+
return dim
147+
else:
148+
raise ValueError(f"Expected integer dimension, got {type(dim)}: {dim}")
149+
return dim
167150

168151

169152
class ColPaliEmbedImage(op.FunctionSpec):
@@ -198,9 +181,7 @@ class ColPaliEmbedImageExecutor:
198181
def analyze(self, _img_bytes: Any) -> type:
199182
# Get shared model and dimension
200183
if self._cached_model_data is None:
201-
self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
202-
self.spec.model
203-
)
184+
self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)
204185

205186
# Return multi-vector type: Variable patches x Fixed hidden dimension
206187
dimension = self._cached_model_data["dimension"]
@@ -218,9 +199,7 @@ def __call__(self, img_bytes: bytes) -> list[list[float]]:
218199

219200
# Get shared model and processor
220201
if self._cached_model_data is None:
221-
self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
222-
self.spec.model
223-
)
202+
self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)
224203

225204
model = self._cached_model_data["model"]
226205
processor = self._cached_model_data["processor"]
@@ -279,9 +258,7 @@ class ColPaliEmbedQueryExecutor:
279258
def analyze(self, _query: Any) -> type:
280259
# Get shared model and dimension
281260
if self._cached_model_data is None:
282-
self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
283-
self.spec.model
284-
)
261+
self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)
285262

286263
# Return multi-vector type: Variable tokens x Fixed hidden dimension
287264
dimension = self._cached_model_data["dimension"]
@@ -297,9 +274,7 @@ def __call__(self, query: str) -> list[list[float]]:
297274

298275
# Get shared model and processor
299276
if self._cached_model_data is None:
300-
self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
301-
self.spec.model
302-
)
277+
self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)
303278

304279
model = self._cached_model_data["model"]
305280
processor = self._cached_model_data["processor"]

0 commit comments

Comments (0)