|
6 | 6 | from typing import Any, Literal
7 | 7 |
8 | 8 | import cocoindex
9 |   | -import torch
  | 9 | +import numpy as np
10 | 10 | from dotenv import load_dotenv
11 |    | -from fastapi import FastAPI, Query
   | 11 | +from fastapi import FastAPI, Query, HTTPException
12 | 12 | from fastapi.middleware.cors import CORSMiddleware
13 | 13 | from fastapi.staticfiles import StaticFiles
14 | 14 | from PIL import Image
15 | 15 | from qdrant_client import QdrantClient
16 |    | -from transformers import CLIPModel, CLIPProcessor
   | 16 | +from colpali_engine.models import ColPali, ColPaliProcessor
   | 17 | +
   | 18 | +
   | 19 | +# --- Config ---
   | 20 | +
   | 21 | +# Use GRPC
   | 22 | +QDRANT_URL = os.getenv("QDRANT_URL", "localhost:6334")
   | 23 | +PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "true").lower() == "true"
   | 24 | +
   | 25 | +# Use HTTP
   | 26 | +# QDRANT_URL = os.getenv("QDRANT_URL", "localhost:6333")
   | 27 | +# PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "false").lower() == "true"
17 | 28 |
18 | 29 | OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/")
19 |    | -QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6334/")
20 |    | -QDRANT_COLLECTION = "ImageSearch"
21 |    | -CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
22 |    | -CLIP_MODEL_DIMENSION = 768
   | 30 | +QDRANT_COLLECTION = "ImageSearchColpali"
   | 31 | +COLPALI_MODEL_NAME = os.getenv("COLPALI_MODEL", "vidore/colpali-v1.2")
   | 32 | +COLPALI_MODEL_DIMENSION = 1031  # Set to match ColPali's output
   | 33 | +
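Whichever way the `ImageSearchColpali` collection gets created, its vector size has to match `COLPALI_MODEL_DIMENSION`, or upserts from the flow will fail. If the collection is not created automatically, a minimal sketch for creating it by hand could look like this (the collection name and dimension follow the constants above; the cosine distance and the REST port are assumptions, not part of this commit):

```python
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(url="http://localhost:6333")  # REST port; assumption
if not client.collection_exists("ImageSearchColpali"):
    client.create_collection(
        collection_name="ImageSearchColpali",
        # size must equal COLPALI_MODEL_DIMENSION (1031) used by the embedders below
        vectors_config=VectorParams(size=1031, distance=Distance.COSINE),
    )
```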
   | 34 | +# --- ColPali model cache and embedding functions ---
   | 35 | +_colpali_model_cache = {}
   | 36 | +
   | 37 | +
   | 38 | +def get_colpali_model(model: str = COLPALI_MODEL_NAME):
   | 39 | +    global _colpali_model_cache
   | 40 | +    if model not in _colpali_model_cache:
   | 41 | +        print(f"Loading ColPali model: {model}")
   | 42 | +        _colpali_model_cache[model] = {
   | 43 | +            "model": ColPali.from_pretrained(model),
   | 44 | +            "processor": ColPaliProcessor.from_pretrained(model),
   | 45 | +        }
   | 46 | +    return _colpali_model_cache[model]["model"], _colpali_model_cache[model][
   | 47 | +        "processor"
   | 48 | +    ]
   | 49 | +
23 | 50 |
   | 51 | +def colpali_embed_image(
   | 52 | +    img_bytes: bytes, model: str = COLPALI_MODEL_NAME
   | 53 | +) -> list[float]:
   | 54 | +    from PIL import Image
   | 55 | +    import torch
   | 56 | +    import io
24 | 57 |
25 |    | -@functools.cache
26 |    | -def get_clip_model() -> tuple[CLIPModel, CLIPProcessor]:
27 |    | -    model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
28 |    | -    processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
29 |    | -    return model, processor
   | 58 | +    colpali_model, processor = get_colpali_model(model)
   | 59 | +    pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
   | 60 | +    inputs = processor.process_images([pil_image])
   | 61 | +    with torch.no_grad():
   | 62 | +        embeddings = colpali_model(**inputs)
   | 63 | +    pooled_embedding = embeddings.mean(dim=-1)
   | 64 | +    result = pooled_embedding[0].cpu().numpy()  # [1031]
   | 65 | +    return result.tolist()
   | 66 | +
   | 67 | +
   | 68 | +def colpali_embed_query(query: str, model: str = COLPALI_MODEL_NAME) -> list[float]:
   | 69 | +    import torch
   | 70 | +    import numpy as np
   | 71 | +
   | 72 | +    colpali_model, processor = get_colpali_model(model)
   | 73 | +    inputs = processor.process_queries([query])
   | 74 | +    with torch.no_grad():
   | 75 | +        embeddings = colpali_model(**inputs)
   | 76 | +    pooled_embedding = embeddings.mean(dim=-1)
   | 77 | +    query_tokens = pooled_embedding[0].cpu().numpy()  # [15]
   | 78 | +    target_length = COLPALI_MODEL_DIMENSION
   | 79 | +    result = np.zeros(target_length, dtype=np.float32)
   | 80 | +    result[: min(len(query_tokens), target_length)] = query_tokens[:target_length]
   | 81 | +    return result.tolist()
   | 82 | +
   | 83 | +
   | 84 | +# --- End ColPali embedding functions ---
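ColPali is a late-interaction model: it returns one 128-dimensional vector per image patch token and per query text token. The two functions above collapse that multi-vector output into a single fixed-length vector by averaging over the embedding dimension and zero-padding the much shorter query to the image length, so both sides fit a plain single-vector Qdrant collection. A shape-only sketch with random tensors (the token counts 1031 and 15 are the illustrative numbers from the comments above; the real counts come from the processor):

```python
import numpy as np
import torch

# Simulated ColPali outputs: (batch, num_tokens, 128)
image_out = torch.randn(1, 1031, 128)  # ~1031 image patch tokens
query_out = torch.randn(1, 15, 128)    # ~15 query tokens

image_vec = image_out.mean(dim=-1)[0].numpy()     # shape (1031,)
query_tokens = query_out.mean(dim=-1)[0].numpy()  # shape (15,)

padded = np.zeros(1031, dtype=np.float32)
padded[: len(query_tokens)] = query_tokens        # zero-pad query to 1031
assert image_vec.shape == padded.shape == (1031,)
```

Averaging across the 128 embedding dimensions and zero-padding is a pragmatic simplification of ColPali's usual MaxSim late-interaction scoring; it keeps the Qdrant schema simple, likely at some cost in retrieval quality.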
30 | 85 |
31 | 86 |
32 | 87 | def embed_query(text: str) -> list[float]:
33 | 88 |     """
34 |    | -    Embed the caption using CLIP model.
   | 89 | +    Embed the caption using ColPali model.
35 | 90 |     """
36 |    | -    model, processor = get_clip_model()
37 |    | -    inputs = processor(text=[text], return_tensors="pt", padding=True)
38 |    | -    with torch.no_grad():
39 |    | -        features = model.get_text_features(**inputs)
40 |    | -    return features[0].tolist()
   | 91 | +    return colpali_embed_query(text, model=COLPALI_MODEL_NAME)
41 | 92 |
42 | 93 |
43 | 94 | @cocoindex.op.function(cache=True, behavior_version=1, gpu=True)
44 | 95 | def embed_image(
45 | 96 |     img_bytes: bytes,
46 |    | -) -> cocoindex.Vector[cocoindex.Float32, Literal[CLIP_MODEL_DIMENSION]]:
   | 97 | +) -> cocoindex.Vector[cocoindex.Float32, Literal[COLPALI_MODEL_DIMENSION]]:
47 | 98 |     """
48 |    | -    Convert image to embedding using CLIP model.
   | 99 | +    Convert image to embedding using ColPali model.
49 | 100 |     """
50 |     | -    model, processor = get_clip_model()
51 |     | -    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
52 |     | -    inputs = processor(images=image, return_tensors="pt")
53 |     | -    with torch.no_grad():
54 |     | -        features = model.get_image_features(**inputs)
55 |     | -    return features[0].tolist()
   | 101 | +    return colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME)
56 | 102 |
57 | 103 |
58 |     | -# CocoIndex flow: Ingest images, extract captions, embed, export to Qdrant
59 |     | -@cocoindex.flow_def(name="ImageObjectEmbedding")
   | 104 | +@cocoindex.flow_def(name="ImageObjectEmbeddingColpali")
60 | 105 | def image_object_embedding_flow(
61 | 106 |     flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
62 | 107 | ) -> None:
63 | 108 |     data_scope["images"] = flow_builder.add_source(
64 | 109 |         cocoindex.sources.LocalFile(
65 | 110 |             path="img", included_patterns=["*.jpg", "*.jpeg", "*.png"], binary=True
66 | 111 |         ),
67 |     | -        refresh_interval=datetime.timedelta(
68 |     | -            minutes=1
69 |     | -        ),  # Poll for changes every 1 minute
   | 112 | +        refresh_interval=datetime.timedelta(minutes=1),
70 | 113 |     )
71 | 114 |     img_embeddings = data_scope.add_collector()
72 | 115 |     with data_scope["images"].row() as img:
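Since both embedders must emit vectors of exactly `COLPALI_MODEL_DIMENSION` entries for the Qdrant export to work, a quick standalone check can catch mismatches before running the flow. A throwaway sketch, not part of the commit (the `img/example.jpg` path is made up; any image under `img/` would do):

```python
if __name__ == "__main__":
    with open("img/example.jpg", "rb") as f:  # hypothetical sample image
        image_vec = colpali_embed_image(f.read())
    query_vec = colpali_embed_query("a cat on a sofa")
    assert len(image_vec) == len(query_vec) == COLPALI_MODEL_DIMENSION
```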
|
@@ -117,7 +160,7 @@ async def lifespan(app: FastAPI) -> None:
117 | 160 |     cocoindex.init()
118 | 161 |     image_object_embedding_flow.setup(report_to_stdout=True)
119 | 162 |
120 |     | -    app.state.qdrant_client = QdrantClient(url=QDRANT_URL, prefer_grpc=True)
    | 163 | +    app.state.qdrant_client = QdrantClient(url=QDRANT_URL, prefer_grpc=PREFER_GRPC)
121 | 164 |
122 | 165 |     # Start updater
123 | 166 |     app.state.live_updater = cocoindex.FlowLiveUpdater(image_object_embedding_flow)
|
@@ -162,9 +205,7 @@ def search(
162 | 205 |             {
163 | 206 |                 "filename": result.payload["filename"],
164 | 207 |                 "score": result.score,
165 |     | -                "caption": result.payload.get(
166 |     | -                    "caption"
167 |     | -                ),  # Include caption if available
    | 208 | +                "caption": result.payload.get("caption"),
168 | 209 |             }
169 | 210 |             for result in search_results
170 | 211 |         ]
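For reference, producing these results comes down to embedding the query text with `embed_query` and running a vector search against the collection. A condensed sketch of that lookup, not the handler itself (the payload keys `filename` and `caption` match the code above; the result limit and the use of `QdrantClient.search` are assumptions):

```python
def search_images(client: QdrantClient, text: str, limit: int = 10) -> list[dict]:
    query_vector = embed_query(text)  # zero-padded ColPali query embedding
    hits = client.search(
        collection_name=QDRANT_COLLECTION,
        query_vector=query_vector,
        limit=limit,
        with_payload=True,
    )
    return [
        {
            "filename": hit.payload["filename"],
            "score": hit.score,
            "caption": hit.payload.get("caption"),
        }
        for hit in hits
    ]
```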
|
|