
Commit 143d57c

Move function to functions.py and use multi-vector search
1 parent fb04748 commit 143d57c

File tree (4 files changed, +184 -68 lines changed):

- examples/image_search_colpali/README.md
- examples/image_search_colpali/main.py
- examples/image_search_colpali/pyproject.toml
- python/cocoindex/functions.py


examples/image_search_colpali/README.md

Lines changed: 13 additions & 2 deletions
@@ -10,8 +10,8 @@ We appreciate a star ⭐ at [CocoIndex Github](https://github.com/cocoindex-io/c
 
 ## Technologies
 - CocoIndex for ETL and live update
-- **ColPali** - Multimodal Embeddings Model for images and query
-- Qdrant for Vector Storage (supports both gRPC and HTTP)
+- **ColPali** - Multimodal Embeddings Model for images and query with multi-vector late interaction
+- Qdrant for Vector Storage with multi-vector support and MaxSim scoring (supports both gRPC and HTTP)
 - FastAPI for backend
 - Ollama (Optional) for generating image captions using `gemma3` or other models

@@ -55,6 +55,17 @@ export OLLAMA_MODEL="gemma3" # Optional, for caption generation
 ```
 pip install -e .
 ```
+Note: ColPali embedding support is included in the cocoindex library with the `[embeddings]` extra.
+
+- The app automatically detects the ColPali model dimension and uses multi-vector embeddings with MaxSim scoring for optimal search performance.
+
+## Supported Models
+- Default: `vidore/colpali-v1.2` (128-dimensional embeddings)
+- Also supports: `vidore/colpali-v1.1`, `vidore/colpali-v1.3`, and other ColPali variants
+- To use a different model, set the environment variable:
+```sh
+export COLPALI_MODEL="vidore/colpali-v1.3"
+```
 
 - Run Backend
 ```
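The updated README highlights Qdrant's multi-vector storage with MaxSim scoring. As a rough sketch of what that looks like on the Qdrant side (illustrative only, not part of this commit; the CocoIndex flow normally manages the collection, and the URL, distance metric, and 128-dim size are assumptions):

```python
# Illustrative sketch only: creating a Qdrant collection with a named
# multi-vector field scored by MaxSim. Assumed values: local URL, cosine
# distance, 128-dim ColPali embeddings.
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")

client.create_collection(
    collection_name="ImageSearchColpali",  # matches QDRANT_COLLECTION in main.py
    vectors_config={
        "embedding": models.VectorParams(
            size=128,  # ColPali v1.2 hidden dim; detected dynamically by the app
            distance=models.Distance.COSINE,
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        )
    },
)
```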

examples/image_search_colpali/main.py

Lines changed: 24 additions & 65 deletions
@@ -2,6 +2,7 @@
 import functools
 import io
 import os
+import typing
 from contextlib import asynccontextmanager
 from typing import Any, Literal
 
@@ -13,7 +14,6 @@
 from fastapi.staticfiles import StaticFiles
 from PIL import Image
 from qdrant_client import QdrantClient
-from colpali_engine.models import ColPali, ColPaliProcessor
 
 
 # --- Config ---
@@ -29,76 +29,29 @@
 OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/")
 QDRANT_COLLECTION = "ImageSearchColpali"
 COLPALI_MODEL_NAME = os.getenv("COLPALI_MODEL", "vidore/colpali-v1.2")
-COLPALI_MODEL_DIMENSION = 1031  # Set to match ColPali's output
-
-# --- ColPali model cache and embedding functions ---
-_colpali_model_cache = {}
-
-
-def get_colpali_model(model: str = COLPALI_MODEL_NAME):
-    global _colpali_model_cache
-    if model not in _colpali_model_cache:
-        print(f"Loading ColPali model: {model}")
-        _colpali_model_cache[model] = {
-            "model": ColPali.from_pretrained(model),
-            "processor": ColPaliProcessor.from_pretrained(model),
-        }
-    return _colpali_model_cache[model]["model"], _colpali_model_cache[model][
-        "processor"
-    ]
-
-
-def colpali_embed_image(
-    img_bytes: bytes, model: str = COLPALI_MODEL_NAME
-) -> list[float]:
-    from PIL import Image
-    import torch
-    import io
-
-    colpali_model, processor = get_colpali_model(model)
-    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
-    inputs = processor.process_images([pil_image])
-    with torch.no_grad():
-        embeddings = colpali_model(**inputs)
-    pooled_embedding = embeddings.mean(dim=-1)
-    result = pooled_embedding[0].cpu().numpy()  # [1031]
-    return result.tolist()
-
-
-def colpali_embed_query(query: str, model: str = COLPALI_MODEL_NAME) -> list[float]:
-    import torch
-    import numpy as np
-
-    colpali_model, processor = get_colpali_model(model)
-    inputs = processor.process_queries([query])
-    with torch.no_grad():
-        embeddings = colpali_model(**inputs)
-    pooled_embedding = embeddings.mean(dim=-1)
-    query_tokens = pooled_embedding[0].cpu().numpy()  # [15]
-    target_length = COLPALI_MODEL_DIMENSION
-    result = np.zeros(target_length, dtype=np.float32)
-    result[: min(len(query_tokens), target_length)] = query_tokens[:target_length]
-    return result.tolist()
-
-
-# --- End ColPali embedding functions ---
+# Get ColPali embedding dimension dynamically from model
+HIDDEN_DIM = cocoindex.functions.get_colpali_dimension(COLPALI_MODEL_NAME)
+print(
+    f"📐 Using ColPali model {COLPALI_MODEL_NAME} with {HIDDEN_DIM} hidden dimensions"
+)
 
 
-def embed_query(text: str) -> list[float]:
+def embed_query(text: str) -> list[list[float]]:
     """
-    Embed the caption using ColPali model.
+    Embed the caption using ColPali model, returning multi-vector format.
     """
-    return colpali_embed_query(text, model=COLPALI_MODEL_NAME)
+    return cocoindex.functions.colpali_embed_query(text, model=COLPALI_MODEL_NAME)
 
 
 @cocoindex.op.function(cache=True, behavior_version=1, gpu=True)
 def embed_image(
     img_bytes: bytes,
-) -> cocoindex.Vector[cocoindex.Float32, Literal[COLPALI_MODEL_DIMENSION]]:
+) -> cocoindex.Vector[cocoindex.Vector[cocoindex.Float32, typing.Literal[HIDDEN_DIM]]]:
     """
-    Convert image to embedding using ColPali model.
+    Convert image to embedding using ColPali model, returning multi-vector format.
+    Returns variable number of patches, each with model-specific dimensional embeddings.
     """
-    return colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME)
+    return cocoindex.functions.colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME)
 
 
 @cocoindex.flow_def(name="ImageObjectEmbeddingColpali")
@@ -189,24 +142,30 @@ def search(
     q: str = Query(..., description="Search query"),
     limit: int = Query(5, description="Number of results"),
 ) -> Any:
-    # Get the embedding for the query
+    # Get the multi-vector embedding for the query
     query_embedding = embed_query(q)
+    print(
+        f"🔍 Query multi-vector shape: {len(query_embedding)} tokens x {len(query_embedding[0]) if query_embedding else 0} dims"
+    )
 
-    # Search in Qdrant
-    search_results = app.state.qdrant_client.search(
+    # Search in Qdrant with multi-vector MaxSim scoring using query_points API
+    search_results = app.state.qdrant_client.query_points(
         collection_name=QDRANT_COLLECTION,
-        query_vector=("embedding", query_embedding),
+        query=query_embedding,  # Multi-vector format: list[list[float]]
+        using="embedding",  # Specify the vector field name
         limit=limit,
         with_payload=True,
     )
 
+    print(f"📈 Found {len(search_results.points)} results with MaxSim scoring")
+
     return {
         "results": [
            {
                "filename": result.payload["filename"],
                "score": result.score,
                "caption": result.payload.get("caption"),
            }
-            for result in search_results
+            for result in search_results.points
        ]
    }
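The search endpoint above relies on Qdrant's server-side MaxSim (late interaction) scoring: for each query token vector, take the maximum similarity over all stored patch vectors of an image, then sum across query tokens. A minimal numpy sketch of that computation, purely for illustration (the helper name `maxsim_score` is made up here):

```python
import numpy as np

def maxsim_score(query_vecs: np.ndarray, doc_vecs: np.ndarray) -> float:
    """MaxSim: best-matching document patch per query token, summed over tokens.
    Shapes: query_vecs [n_tokens, dim], doc_vecs [n_patches, dim]."""
    sim = query_vecs @ doc_vecs.T        # [n_tokens, n_patches] dot-product similarities
    return float(sim.max(axis=1).sum())  # max over patches per token, then sum

# Toy example: 2 query tokens, 3 document patches, dim=4
query = np.array([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0]])
doc = np.array([[0.9, 0.1, 0.0, 0.0],
                [0.0, 0.8, 0.2, 0.0],
                [0.1, 0.1, 0.9, 0.0]])
print(maxsim_score(query, doc))  # 0.9 + 0.8 = 1.7
```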

examples/image_search_colpali/pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@ dependencies = [
     "torch>=2.0.0",
     "qdrant-client>=1.14.2",
     "uvicorn>=0.34.3",
-    "colpali-engine>=0.1.0",
     "Pillow>=10.0.0",
     "numpy>=1.24.0",
 ]

python/cocoindex/functions.py

Lines changed: 147 additions & 0 deletions
@@ -99,3 +99,150 @@ def __call__(self, text: str) -> NDArray[np.float32]:
         assert self._model is not None
         result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
         return result
+
+
+# ColPali model cache for ColPali embedding functions
+_colpali_model_cache = {}
+
+
+def get_colpali_model(model: str) -> tuple[Any, Any, int]:
+    """Get or load ColPali model and processor."""
+    global _colpali_model_cache
+    if model not in _colpali_model_cache:
+        try:
+            from colpali_engine.models import ColPali, ColPaliProcessor  # type: ignore[import-untyped]
+        except ImportError as e:
+            raise ImportError(
+                "ColPali is not available. Make sure cocoindex is installed with ColPali support."
+            ) from e
+
+        model_instance = ColPali.from_pretrained(model)
+        processor_instance = ColPaliProcessor.from_pretrained(model)
+
+        # Try to get dimension from FastEmbed API first
+        output_dim = None
+        try:
+            from fastembed import LateInteractionMultimodalEmbedding
+
+            # Use the standard FastEmbed ColPali model for dimension detection
+            # All ColPali variants should have the same embedding dimension
+            standard_colpali_model = "Qdrant/colpali-v1.3-fp16"
+
+            # Try to find the model in FastEmbed's supported models
+            supported_models = (
+                LateInteractionMultimodalEmbedding.list_supported_models()
+            )
+            for supported_model in supported_models:
+                if supported_model["model"] == standard_colpali_model:
+                    output_dim = supported_model["dim"]
+                    break
+
+        except Exception:
+            # FastEmbed API failed, will fall back to model config
+            pass
+
+        # Fallback to model config if FastEmbed API failed
+        if output_dim is None:
+            if hasattr(model_instance, "config"):
+                # Try different config attributes that might contain the hidden dimension
+                if hasattr(model_instance.config, "hidden_size"):
+                    output_dim = model_instance.config.hidden_size
+                elif hasattr(model_instance.config, "text_config") and hasattr(
+                    model_instance.config.text_config, "hidden_size"
+                ):
+                    output_dim = model_instance.config.text_config.hidden_size
+                elif hasattr(model_instance.config, "vision_config") and hasattr(
+                    model_instance.config.vision_config, "hidden_size"
+                ):
+                    output_dim = model_instance.config.vision_config.hidden_size
+                else:
+                    raise ValueError(
+                        f"Could not find hidden_size in model config for {model}. Config attributes: {dir(model_instance.config)}"
+                    )
+            else:
+                raise ValueError(
+                    f"Model {model} has no config attribute. Model attributes: {dir(model_instance)}"
+                )
+
+        _colpali_model_cache[model] = {
+            "model": model_instance,
+            "processor": processor_instance,
+            "dimension": output_dim,
+        }
+    return (
+        _colpali_model_cache[model]["model"],
+        _colpali_model_cache[model]["processor"],
+        _colpali_model_cache[model]["dimension"],
+    )
+
+
+def get_colpali_dimension(model: str) -> int:
+    """Get the output dimension for a ColPali model."""
+    _, _, dimension = get_colpali_model(model)
+    return int(dimension)
+
+
+def colpali_embed_image(img_bytes: bytes, model: str) -> list[list[float]]:
+    """Embed image using ColPali model, returning multi-vector format."""
+    try:
+        from PIL import Image
+        import torch
+        import io
+    except ImportError as e:
+        raise ImportError(
+            "Required dependencies (PIL, torch) are missing for ColPali image embedding."
+        ) from e
+
+    colpali_model, processor, expected_dim = get_colpali_model(model)
+    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+    inputs = processor.process_images([pil_image])
+    with torch.no_grad():
+        embeddings = colpali_model(**inputs)
+
+    # Return multi-vector format: [patches, hidden_dim]
+    if len(embeddings.shape) != 3:
+        raise ValueError(
+            f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}"
+        )
+
+    # Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
+    patch_embeddings = embeddings[0]  # Remove batch dimension
+
+    # Convert to list of lists: [[patch1_embedding], [patch2_embedding], ...]
+    result = []
+    for patch in patch_embeddings:
+        result.append(patch.cpu().numpy().tolist())
+
+    return result
+
+
+def colpali_embed_query(query: str, model: str) -> list[list[float]]:
+    """Embed query using ColPali model, returning multi-vector format."""
+    try:
+        import torch
+        import numpy as np
+    except ImportError as e:
+        raise ImportError(
+            "Required dependencies (torch, numpy) are missing for ColPali query embedding."
+        ) from e
+
+    colpali_model, processor, target_dimension = get_colpali_model(model)
+    inputs = processor.process_queries([query])
+    with torch.no_grad():
+        embeddings = colpali_model(**inputs)
+
+    # Return multi-vector format: [tokens, hidden_dim]
+    if len(embeddings.shape) != 3:
+        raise ValueError(
+            f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}"
+        )
+
+    # Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
+    token_embeddings = embeddings[0]  # Remove batch dimension
+
+    # Convert to list of lists: [[token1_embedding], [token2_embedding], ...]
+    result = []
+    for token in token_embeddings:
+        result.append(token.cpu().numpy().tolist())
+
+    return result
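For a quick manual check of the new helpers, something like the following usage sketch works once the model weights are available; `cat.jpg` is a placeholder path and the printed shapes depend on the image and query:

```python
# Usage sketch for the helpers added to python/cocoindex/functions.py.
# "cat.jpg" is a placeholder; the ColPali model is downloaded on first use.
from cocoindex.functions import (
    get_colpali_dimension,
    colpali_embed_image,
    colpali_embed_query,
)

model = "vidore/colpali-v1.2"
dim = get_colpali_dimension(model)
print(f"hidden dim: {dim}")  # e.g. 128

with open("cat.jpg", "rb") as f:
    img_vecs = colpali_embed_image(f.read(), model)
print(len(img_vecs), len(img_vecs[0]))  # n_patches x dim

query_vecs = colpali_embed_query("a cat sleeping on a sofa", model)
print(len(query_vecs), len(query_vecs[0]))  # n_tokens x dim
```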
