import datetime
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

import cocoindex
from dotenv import load_dotenv
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from qdrant_client import QdrantClient


# --- Config ---

# Use gRPC (default)
QDRANT_URL = os.getenv("QDRANT_URL", "localhost:6334")
PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "true").lower() == "true"

# Use HTTP instead:
# QDRANT_URL = os.getenv("QDRANT_URL", "localhost:6333")
# PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "false").lower() == "true"

OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/")
QDRANT_COLLECTION = "ImageSearchColpali"
COLPALI_MODEL_NAME = os.getenv("COLPALI_MODEL", "vidore/colpali-v1.2")
print(f"📐 Using ColPali model {COLPALI_MODEL_NAME}")
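
# A hypothetical .env sketch for the settings above (example values only, not
# defaults shipped with this code; adjust to your own deployment):
#
#   QDRANT_URL=localhost:6334
#   QDRANT_PREFER_GRPC=true
#   OLLAMA_URL=http://localhost:11434/
#   OLLAMA_MODEL=gemma3          # optional; enables image captioning
#   COLPALI_MODEL=vidore/colpali-v1.2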


# Create the ColPali image-embedding function using the class-based pattern
colpali_embed = cocoindex.functions.ColPaliEmbedImage(model=COLPALI_MODEL_NAME)


@cocoindex.transform_flow()
def text_to_colpali_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[list[list[float]]]:
    """
    Embed text using a ColPali model, returning a multi-vector (one vector per token).
    This logic is shared between indexing and querying, ensuring consistent embeddings.
    """
    return text.transform(
        cocoindex.functions.ColPaliEmbedQuery(model=COLPALI_MODEL_NAME)
    )
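
# Illustration (not executed here): evaluating the shared transform directly yields
# the same multi-vector the /search endpoint below computes for a query, i.e. one
# embedding vector per query token:
#
#   vecs = text_to_colpali_embedding.eval("a cat on a sofa")
#   len(vecs)     # number of query tokens
#   len(vecs[0])  # per-token embedding dimension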


@cocoindex.flow_def(name="ImageObjectEmbeddingColpali")
def image_object_embedding_flow(
    flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
) -> None:
    data_scope["images"] = flow_builder.add_source(
        cocoindex.sources.LocalFile(
            path="img", included_patterns=["*.jpg", "*.jpeg", "*.png"], binary=True
        ),
        refresh_interval=datetime.timedelta(minutes=1),
    )
    img_embeddings = data_scope.add_collector()
    with data_scope["images"].row() as img:
        ollama_model_name = os.getenv("OLLAMA_MODEL")
        if ollama_model_name is not None:
            # If an Ollama model is specified, generate an image caption
            img["caption"] = flow_builder.transform(
                cocoindex.functions.ExtractByLlm(
                    llm_spec=cocoindex.llm.LlmSpec(
                        api_type=cocoindex.LlmApiType.OLLAMA, model=ollama_model_name
                    ),
                    instruction=(
                        "Describe the image in one detailed sentence. "
                        "Name all visible animal species, objects, and the main scene. "
                        "Be specific about type, color, and notable features. "
                        "Mention what each animal is doing."
                    ),
                    output_type=str,
                ),
                image=img["content"],
            )
        img["embedding"] = img["content"].transform(colpali_embed)

        collect_fields = {
            "id": cocoindex.GeneratedField.UUID,
            "filename": img["filename"],
            "embedding": img["embedding"],
        }

        if ollama_model_name is not None:
            print(f"Using Ollama model '{ollama_model_name}' for captioning.")
            collect_fields["caption"] = img["caption"]
        else:
            print("No OLLAMA_MODEL configured; skipping captioning.")

        img_embeddings.collect(**collect_fields)

    img_embeddings.export(
        "img_embeddings",
        cocoindex.targets.Qdrant(collection_name=QDRANT_COLLECTION),
        primary_key_fields=["id"],
    )
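

# The export above targets a Qdrant collection that stores multi-vector points and
# scores them with MaxSim. `image_object_embedding_flow.setup()` (called in the
# lifespan below) is expected to provision it; the helper below is only a minimal,
# uncalled sketch of doing the same by hand with qdrant-client, assuming 128-dim
# ColPali token embeddings and the "embedding" vector name used by the flow.
def _create_colpali_collection_sketch(client: QdrantClient, dim: int = 128) -> None:
    from qdrant_client import models

    if not client.collection_exists(QDRANT_COLLECTION):
        client.create_collection(
            collection_name=QDRANT_COLLECTION,
            vectors_config={
                "embedding": models.VectorParams(
                    size=dim,
                    distance=models.Distance.COSINE,
                    # The MaxSim comparator enables late-interaction scoring
                    # over the per-token vectors.
                    multivector_config=models.MultivectorConfig(
                        comparator=models.MultivectorComparator.MAX_SIM
                    ),
                )
            },
        )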


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    load_dotenv()
    cocoindex.init()
    image_object_embedding_flow.setup(report_to_stdout=True)

    app.state.qdrant_client = QdrantClient(url=QDRANT_URL, prefer_grpc=PREFER_GRPC)

    # Start the live updater so new or changed images are indexed while the API runs
    app.state.live_updater = cocoindex.FlowLiveUpdater(image_object_embedding_flow)
    app.state.live_updater.start()

    yield


# --- FastAPI app for web API ---
app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve images from the 'img' directory at /img
app.mount("/img", StaticFiles(directory="img"), name="img")


# --- Search API ---
@app.get("/search")
def search(
    q: str = Query(..., description="Search query"),
    limit: int = Query(5, description="Number of results"),
) -> Any:
    # Get the multi-vector embedding for the query
    query_embedding = text_to_colpali_embedding.eval(q)
    print(
        f"🔍 Query multi-vector shape: {len(query_embedding)} tokens x "
        f"{len(query_embedding[0]) if query_embedding else 0} dims"
    )

    # Search in Qdrant with multi-vector MaxSim scoring via the query_points API
    search_results = app.state.qdrant_client.query_points(
        collection_name=QDRANT_COLLECTION,
        query=query_embedding,  # Multi-vector format: list[list[float]]
        using="embedding",  # Name of the vector field to search
        limit=limit,
        with_payload=True,
    )

    print(f"📈 Found {len(search_results.points)} results with MaxSim scoring")

    return {
        "results": [
            {
                "filename": result.payload["filename"],
                "score": result.score,
                "caption": result.payload.get("caption"),
            }
            for result in search_results.points
        ]
    }
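

# Local run sketch (an assumption, not part of the original code): with uvicorn
# installed, the API plus live indexer can be started directly and then queried
# with e.g.:
#   curl "http://localhost:8000/search?q=a+brown+dog+running&limit=5"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)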