
Commit aa820bd

Move function to functions.py and use multi-vector search
1 parent fb04748 commit aa820bd

4 files changed: +162 -67 lines

examples/image_search_colpali/README.md

Lines changed: 13 additions & 2 deletions
@@ -10,8 +10,8 @@ We appreciate a star ⭐ at [CocoIndex Github](https://github.com/cocoindex-io/c
 
 ## Technologies
 - CocoIndex for ETL and live update
-- **ColPali** - Multimodal Embeddings Model for images and query
-- Qdrant for Vector Storage (supports both gRPC and HTTP)
+- **ColPali** - Multimodal Embeddings Model for images and query with multi-vector late interaction
+- Qdrant for Vector Storage with multi-vector support and MaxSim scoring (supports both gRPC and HTTP)
 - FastAPI for backend
 - Ollama (Optional) for generating image captions using `gemma3` or other models
 
@@ -55,6 +55,17 @@ export OLLAMA_MODEL="gemma3" # Optional, for caption generation
 ```
 pip install -e .
 ```
+Note: ColPali embedding support is included in the cocoindex library with the `[embeddings]` extra.
+
+- The app automatically detects the ColPali model dimension and uses multi-vector embeddings with MaxSim scoring for optimal search performance.
+
+## Supported Models
+- Default: `vidore/colpali-v1.2` (128-dimensional embeddings)
+- Also supports: `vidore/colpali-v1.1`, `vidore/colpali-v1.3`, and other ColPali variants
+- To use a different model, set the environment variable:
+```sh
+export COLPALI_MODEL="vidore/colpali-v1.3"
+```
 
 - Run Backend
 ```
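For readers unfamiliar with the late-interaction scoring referenced above, the sketch below shows what MaxSim computes over multi-vector embeddings: for each query token vector, take the maximum similarity over all image patch vectors, then sum those maxima. This is an illustrative numpy sketch, not code from this commit.

```python
import numpy as np

def maxsim_score(query_vecs: np.ndarray, image_vecs: np.ndarray) -> float:
    """Late-interaction (MaxSim) score between a multi-vector query and image.

    query_vecs: [num_query_tokens, dim] -- one vector per query token.
    image_vecs: [num_patches, dim]      -- one vector per image patch.
    """
    # Pairwise similarities between every query token and every image patch.
    sims = query_vecs @ image_vecs.T  # [num_query_tokens, num_patches]
    # Each query token keeps only its best-matching patch; the score is the sum.
    return float(sims.max(axis=1).sum())
```

Qdrant can apply this comparator server-side when the vector field is configured for multi-vectors, so the application does not compute it itself (see the collection sketch after the main.py diff below).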

examples/image_search_colpali/main.py

Lines changed: 20 additions & 64 deletions
@@ -2,6 +2,7 @@
 import functools
 import io
 import os
+import typing
 from contextlib import asynccontextmanager
 from typing import Any, Literal
 
@@ -13,7 +14,6 @@
 from fastapi.staticfiles import StaticFiles
 from PIL import Image
 from qdrant_client import QdrantClient
-from colpali_engine.models import ColPali, ColPaliProcessor
 
 
 # --- Config ---
@@ -29,76 +29,28 @@
 OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/")
 QDRANT_COLLECTION = "ImageSearchColpali"
 COLPALI_MODEL_NAME = os.getenv("COLPALI_MODEL", "vidore/colpali-v1.2")
-COLPALI_MODEL_DIMENSION = 1031 # Set to match ColPali's output
+# Get ColPali embedding dimension dynamically from model
+HIDDEN_DIM = cocoindex.functions.get_colpali_dimension(COLPALI_MODEL_NAME)
+print(f"📐 Using ColPali model {COLPALI_MODEL_NAME} with {HIDDEN_DIM} hidden dimensions")
 
-# --- ColPali model cache and embedding functions ---
-_colpali_model_cache = {}
 
 
-def get_colpali_model(model: str = COLPALI_MODEL_NAME):
-    global _colpali_model_cache
-    if model not in _colpali_model_cache:
-        print(f"Loading ColPali model: {model}")
-        _colpali_model_cache[model] = {
-            "model": ColPali.from_pretrained(model),
-            "processor": ColPaliProcessor.from_pretrained(model),
-        }
-    return _colpali_model_cache[model]["model"], _colpali_model_cache[model][
-        "processor"
-    ]
-
-
-def colpali_embed_image(
-    img_bytes: bytes, model: str = COLPALI_MODEL_NAME
-) -> list[float]:
-    from PIL import Image
-    import torch
-    import io
-
-    colpali_model, processor = get_colpali_model(model)
-    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
-    inputs = processor.process_images([pil_image])
-    with torch.no_grad():
-        embeddings = colpali_model(**inputs)
-    pooled_embedding = embeddings.mean(dim=-1)
-    result = pooled_embedding[0].cpu().numpy() # [1031]
-    return result.tolist()
-
-
-def colpali_embed_query(query: str, model: str = COLPALI_MODEL_NAME) -> list[float]:
-    import torch
-    import numpy as np
-
-    colpali_model, processor = get_colpali_model(model)
-    inputs = processor.process_queries([query])
-    with torch.no_grad():
-        embeddings = colpali_model(**inputs)
-    pooled_embedding = embeddings.mean(dim=-1)
-    query_tokens = pooled_embedding[0].cpu().numpy() # [15]
-    target_length = COLPALI_MODEL_DIMENSION
-    result = np.zeros(target_length, dtype=np.float32)
-    result[: min(len(query_tokens), target_length)] = query_tokens[:target_length]
-    return result.tolist()
-
-
-# --- End ColPali embedding functions ---
-
-
-def embed_query(text: str) -> list[float]:
+def embed_query(text: str) -> list[list[float]]:
     """
-    Embed the caption using ColPali model.
+    Embed the caption using ColPali model, returning multi-vector format.
     """
-    return colpali_embed_query(text, model=COLPALI_MODEL_NAME)
+    return cocoindex.functions.colpali_embed_query(text, model=COLPALI_MODEL_NAME)
 
 
 @cocoindex.op.function(cache=True, behavior_version=1, gpu=True)
 def embed_image(
     img_bytes: bytes,
-) -> cocoindex.Vector[cocoindex.Float32, Literal[COLPALI_MODEL_DIMENSION]]:
+) -> cocoindex.Vector[cocoindex.Vector[cocoindex.Float32, typing.Literal[HIDDEN_DIM]]]:
     """
-    Convert image to embedding using ColPali model.
+    Convert image to embedding using ColPali model, returning multi-vector format.
+    Returns variable number of patches, each with model-specific dimensional embeddings.
     """
-    return colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME)
+    return cocoindex.functions.colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME)
 
 
 @cocoindex.flow_def(name="ImageObjectEmbeddingColpali")
@@ -189,16 +141,20 @@ def search(
     q: str = Query(..., description="Search query"),
     limit: int = Query(5, description="Number of results"),
 ) -> Any:
-    # Get the embedding for the query
+    # Get the multi-vector embedding for the query
     query_embedding = embed_query(q)
+    print(f"🔍 Query multi-vector shape: {len(query_embedding)} tokens x {len(query_embedding[0]) if query_embedding else 0} dims")
 
-    # Search in Qdrant
-    search_results = app.state.qdrant_client.search(
+    # Search in Qdrant with multi-vector MaxSim scoring using query_points API
+    search_results = app.state.qdrant_client.query_points(
         collection_name=QDRANT_COLLECTION,
-        query_vector=("embedding", query_embedding),
+        query=query_embedding, # Multi-vector format: list[list[float]]
+        using="embedding", # Specify the vector field name
        limit=limit,
         with_payload=True,
     )
+
+    print(f"📈 Found {len(search_results.points)} results with MaxSim scoring")
 
     return {
         "results": [
@@ -207,6 +163,6 @@ def search(
                 "score": result.score,
                 "caption": result.payload.get("caption"),
             }
-            for result in search_results
+            for result in search_results.points
        ]
    }
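The `query_points` call above assumes the `embedding` field of the `ImageSearchColpali` collection was created with Qdrant's multi-vector configuration and MaxSim comparator. Collection setup is outside this diff, so the following is a minimal sketch of such a setup (the URL and per-vector size are illustrative; the README cites 128 dimensions for the default model):

```python
from qdrant_client import QdrantClient, models

# Illustrative connection; the example app reads its Qdrant URL from the environment.
client = QdrantClient(url="http://localhost:6333")

client.create_collection(
    collection_name="ImageSearchColpali",
    vectors_config={
        # Named vector, matching `using="embedding"` in the search endpoint above.
        "embedding": models.VectorParams(
            size=128,  # per-token/patch dimension (128 for vidore/colpali-v1.2 per the README)
            distance=models.Distance.COSINE,
            # Store a list of vectors per point and score them with MaxSim.
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        ),
    },
)
```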

examples/image_search_colpali/pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@ dependencies = [
     "torch>=2.0.0",
     "qdrant-client>=1.14.2",
     "uvicorn>=0.34.3",
-    "colpali-engine>=0.1.0",
     "Pillow>=10.0.0",
     "numpy>=1.24.0",
 ]

python/cocoindex/functions.py

Lines changed: 129 additions & 0 deletions
@@ -99,3 +99,132 @@ def __call__(self, text: str) -> NDArray[np.float32]:
         assert self._model is not None
         result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
         return result
+
+
+# ColPali model cache for ColPali embedding functions
+_colpali_model_cache = {}
+
+
+def get_colpali_model(model: str):
+    """Get or load ColPali model and processor."""
+    global _colpali_model_cache
+    if model not in _colpali_model_cache:
+        try:
+            from colpali_engine.models import ColPali, ColPaliProcessor
+        except ImportError as e:
+            raise ImportError(
+                "ColPali is not available. Make sure cocoindex is installed with ColPali support."
+            ) from e
+
+        model_instance = ColPali.from_pretrained(model)
+        processor_instance = ColPaliProcessor.from_pretrained(model)
+
+        # Try to get dimension from FastEmbed API first
+        output_dim = None
+        try:
+            from fastembed import LateInteractionMultimodalEmbedding
+
+            # Use the standard FastEmbed ColPali model for dimension detection
+            # All ColPali variants should have the same embedding dimension
+            standard_colpali_model = "Qdrant/colpali-v1.3-fp16"
+
+            # Try to find the model in FastEmbed's supported models
+            supported_models = LateInteractionMultimodalEmbedding.list_supported_models()
+            for supported_model in supported_models:
+                if supported_model["model"] == standard_colpali_model:
+                    output_dim = supported_model["dim"]
+                    break
+
+        except Exception:
+            # FastEmbed API failed, will fall back to model config
+            pass
+
+        # Fallback to model config if FastEmbed API failed
+        if output_dim is None:
+            if hasattr(model_instance, 'config'):
+                # Try different config attributes that might contain the hidden dimension
+                if hasattr(model_instance.config, 'hidden_size'):
+                    output_dim = model_instance.config.hidden_size
+                elif hasattr(model_instance.config, 'text_config') and hasattr(model_instance.config.text_config, 'hidden_size'):
+                    output_dim = model_instance.config.text_config.hidden_size
+                elif hasattr(model_instance.config, 'vision_config') and hasattr(model_instance.config.vision_config, 'hidden_size'):
+                    output_dim = model_instance.config.vision_config.hidden_size
+                else:
+                    raise ValueError(f"Could not find hidden_size in model config for {model}. Config attributes: {dir(model_instance.config)}")
+            else:
+                raise ValueError(f"Model {model} has no config attribute. Model attributes: {dir(model_instance)}")
+
+        _colpali_model_cache[model] = {
+            "model": model_instance,
+            "processor": processor_instance,
+            "dimension": output_dim,
+        }
+    return _colpali_model_cache[model]["model"], _colpali_model_cache[model]["processor"], _colpali_model_cache[model]["dimension"]
+
+
+def get_colpali_dimension(model: str) -> int:
+    """Get the output dimension for a ColPali model."""
+    _, _, dimension = get_colpali_model(model)
+    return dimension
+
+
+def colpali_embed_image(img_bytes: bytes, model: str) -> list[list[float]]:
+    """Embed image using ColPali model, returning multi-vector format."""
+    try:
+        from PIL import Image
+        import torch
+        import io
+    except ImportError as e:
+        raise ImportError(
+            "Required dependencies (PIL, torch) are missing for ColPali image embedding."
+        ) from e
+
+    colpali_model, processor, expected_dim = get_colpali_model(model)
+    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+    inputs = processor.process_images([pil_image])
+    with torch.no_grad():
+        embeddings = colpali_model(**inputs)
+
+    # Return multi-vector format: [patches, hidden_dim]
+    if len(embeddings.shape) != 3:
+        raise ValueError(f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}")
+
+    # Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
+    patch_embeddings = embeddings[0] # Remove batch dimension
+
+    # Convert to list of lists: [[patch1_embedding], [patch2_embedding], ...]
+    result = []
+    for patch in patch_embeddings:
+        result.append(patch.cpu().numpy().tolist())
+
+    return result
+
+
+def colpali_embed_query(query: str, model: str) -> list[list[float]]:
+    """Embed query using ColPali model, returning multi-vector format."""
+    try:
+        import torch
+        import numpy as np
+    except ImportError as e:
+        raise ImportError(
+            "Required dependencies (torch, numpy) are missing for ColPali query embedding."
+        ) from e
+
+    colpali_model, processor, target_dimension = get_colpali_model(model)
+    inputs = processor.process_queries([query])
+    with torch.no_grad():
+        embeddings = colpali_model(**inputs)
+
+    # Return multi-vector format: [tokens, hidden_dim]
+    if len(embeddings.shape) != 3:
+        raise ValueError(f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}")
+
+    # Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
+    token_embeddings = embeddings[0] # Remove batch dimension
+
+    # Convert to list of lists: [[token1_embedding], [token2_embedding], ...]
+    result = []
+    for token in token_embeddings:
+        result.append(token.cpu().numpy().tolist())
+
+    return result
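A brief usage sketch of the helpers added above, mirroring the calls main.py now makes (the image path and query text are illustrative; the model is downloaded and cached on first use):

```python
import cocoindex.functions as cf

MODEL = "vidore/colpali-v1.2"

# Per-token/patch embedding dimension detected for the chosen model.
dim = cf.get_colpali_dimension(MODEL)

# Multi-vector image embedding: one vector per image patch.
with open("example.jpg", "rb") as f:  # illustrative path
    image_vectors = cf.colpali_embed_image(f.read(), model=MODEL)

# Multi-vector query embedding: one vector per query token.
query_vectors = cf.colpali_embed_query("a cat on a sofa", model=MODEL)

print(f"{len(image_vectors)} patches x {len(image_vectors[0])} dims (dim={dim})")
print(f"{len(query_vectors)} tokens x {len(query_vectors[0])} dims")
```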
