@@ -1,9 +1,7 @@
 import datetime
-import functools
-import io
 import os
 from contextlib import asynccontextmanager
-from typing import Any, Literal
+from typing import Any

 import cocoindex
 import numpy as np
@@ -13,7 +11,6 @@
 from fastapi.staticfiles import StaticFiles
 from PIL import Image
 from qdrant_client import QdrantClient
-from colpali_engine.models import ColPali, ColPaliProcessor


 # --- Config ---
@@ -29,76 +26,24 @@
 OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/")
 QDRANT_COLLECTION = "ImageSearchColpali"
 COLPALI_MODEL_NAME = os.getenv("COLPALI_MODEL", "vidore/colpali-v1.2")
-COLPALI_MODEL_DIMENSION = 1031  # Set to match ColPali's output
+print(f"📐 Using ColPali model {COLPALI_MODEL_NAME}")

-# --- ColPali model cache and embedding functions ---
-_colpali_model_cache = {}

+# Create ColPali embedding function using the class-based pattern
+colpali_embed = cocoindex.functions.ColPaliEmbedImage(model=COLPALI_MODEL_NAME)

-def get_colpali_model(model: str = COLPALI_MODEL_NAME):
-    global _colpali_model_cache
-    if model not in _colpali_model_cache:
-        print(f"Loading ColPali model: {model}")
-        _colpali_model_cache[model] = {
-            "model": ColPali.from_pretrained(model),
-            "processor": ColPaliProcessor.from_pretrained(model),
-        }
-    return _colpali_model_cache[model]["model"], _colpali_model_cache[model][
-        "processor"
-    ]
-
-
-def colpali_embed_image(
-    img_bytes: bytes, model: str = COLPALI_MODEL_NAME
-) -> list[float]:
-    from PIL import Image
-    import torch
-    import io
-
-    colpali_model, processor = get_colpali_model(model)
-    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
-    inputs = processor.process_images([pil_image])
-    with torch.no_grad():
-        embeddings = colpali_model(**inputs)
-    pooled_embedding = embeddings.mean(dim=-1)
-    result = pooled_embedding[0].cpu().numpy()  # [1031]
-    return result.tolist()
-
-
-def colpali_embed_query(query: str, model: str = COLPALI_MODEL_NAME) -> list[float]:
-    import torch
-    import numpy as np
-
-    colpali_model, processor = get_colpali_model(model)
-    inputs = processor.process_queries([query])
-    with torch.no_grad():
-        embeddings = colpali_model(**inputs)
-    pooled_embedding = embeddings.mean(dim=-1)
-    query_tokens = pooled_embedding[0].cpu().numpy()  # [15]
-    target_length = COLPALI_MODEL_DIMENSION
-    result = np.zeros(target_length, dtype=np.float32)
-    result[: min(len(query_tokens), target_length)] = query_tokens[:target_length]
-    return result.tolist()
-
-
-# --- End ColPali embedding functions ---

-
-def embed_query(text: str) -> list[float]:
-    """
-    Embed the caption using ColPali model.
-    """
-    return colpali_embed_query(text, model=COLPALI_MODEL_NAME)
-
-
-@cocoindex.op.function(cache=True, behavior_version=1, gpu=True)
-def embed_image(
-    img_bytes: bytes,
-) -> cocoindex.Vector[cocoindex.Float32, Literal[COLPALI_MODEL_DIMENSION]]:
+@cocoindex.transform_flow()
+def text_to_colpali_embedding(
+    text: cocoindex.DataSlice[str],
+) -> cocoindex.DataSlice[list[list[float]]]:
     """
-    Convert image to embedding using ColPali model.
+    Embed text using a ColPali model, returning multi-vector format.
+    This is shared logic between indexing and querying, ensuring consistent embeddings.
     """
-    return colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME)
+    return text.transform(
+        cocoindex.functions.ColPaliEmbedQuery(model=COLPALI_MODEL_NAME)
+    )


 @cocoindex.flow_def(name="ImageObjectEmbeddingColpali")
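Because the flow now stores per-token (multi-vector) embeddings rather than a single pooled vector, the Qdrant collection has to be configured with a MaxSim multivector comparator. That setup is not part of this hunk; the sketch below is only an illustration of what such a collection could look like if created by hand. The collection name and the "embedding" vector field come from this diff; the 128-dim per-token width, the cosine distance, and the Qdrant URL are assumptions about this particular setup.

```python
# Hypothetical collection setup for multi-vector MaxSim search (not part of this change).
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # assumed local Qdrant instance
client.create_collection(
    collection_name="ImageSearchColpali",
    vectors_config={
        "embedding": models.VectorParams(
            size=128,  # assumed per-token projection width of the ColPali model
            distance=models.Distance.COSINE,  # assumption; dot product also works for normalized vectors
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        )
    },
)
```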
@@ -131,7 +76,7 @@ def image_object_embedding_flow(
             ),
             image=img["content"],
         )
-        img["embedding"] = img["content"].transform(embed_image)
+        img["embedding"] = img["content"].transform(colpali_embed)

         collect_fields = {
             "id": cocoindex.GeneratedField.UUID,
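The one-line change above swaps the old pooled embedding for the class-based image embedder. For intuition, here is roughly what the stored multi-vector looks like when computed directly with colpali_engine, as the removed helper did before it collapsed the token axis with `.mean(dim=-1)`. It is a sketch only; the assumption is that `ColPaliEmbedImage` keeps the full `[n_patches, dim]` matrix instead of pooling it.

```python
# Illustrative sketch of a per-image ColPali multi-vector (not the code path used by the flow).
import io

import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from PIL import Image

model = ColPali.from_pretrained("vidore/colpali-v1.2")
processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")


def image_multivector(img_bytes: bytes) -> list[list[float]]:
    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    inputs = processor.process_images([pil_image])
    with torch.no_grad():
        embeddings = model(**inputs)  # shape [1, n_patches, dim]
    # Keep one vector per image patch instead of pooling them into a single vector.
    return embeddings[0].cpu().float().numpy().tolist()
```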
@@ -189,24 +134,30 @@ def search(
     q: str = Query(..., description="Search query"),
     limit: int = Query(5, description="Number of results"),
 ) -> Any:
-    # Get the embedding for the query
-    query_embedding = embed_query(q)
+    # Get the multi-vector embedding for the query
+    query_embedding = text_to_colpali_embedding.eval(q)
+    print(
+        f"🔍 Query multi-vector shape: {len(query_embedding)} tokens x {len(query_embedding[0]) if query_embedding else 0} dims"
+    )

-    # Search in Qdrant
-    search_results = app.state.qdrant_client.search(
+    # Search in Qdrant with multi-vector MaxSim scoring using query_points API
+    search_results = app.state.qdrant_client.query_points(
         collection_name=QDRANT_COLLECTION,
-        query_vector=("embedding", query_embedding),
+        query=query_embedding,  # Multi-vector format: list[list[float]]
+        using="embedding",  # Specify the vector field name
         limit=limit,
         with_payload=True,
     )

+    print(f"📈 Found {len(search_results.points)} results with MaxSim scoring")
+
     return {
         "results": [
             {
                 "filename": result.payload["filename"],
                 "score": result.score,
                 "caption": result.payload.get("caption"),
             }
-            for result in search_results
+            for result in search_results.points
         ]
     }
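For reference, the MaxSim late-interaction score that `query_points` applies when the collection's comparator is MAX_SIM can be sketched in a few lines of numpy: for each query token vector, take its best similarity against the document's patch vectors, then sum over query tokens. This is an illustration of the scoring idea, not the code path the example runs.

```python
# Rough sketch of MaxSim scoring between a query multi-vector and a document multi-vector.
import numpy as np


def maxsim(query_vecs: list[list[float]], doc_vecs: list[list[float]]) -> float:
    q = np.asarray(query_vecs)  # [n_query_tokens, dim]
    d = np.asarray(doc_vecs)    # [n_patches, dim]
    sim = q @ d.T               # token-to-patch similarity matrix
    return float(sim.max(axis=1).sum())  # best patch per query token, summed
```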