Skip to content

Commit e971120

Browse files
WaVEVtimgraham
authored andcommitted
add support for subqueries
Subquery, Exists, and QuerySet as a lookup value.
1 parent 3c24e10 commit e971120

File tree

7 files changed

+212
-191
lines changed

7 files changed

+212
-191
lines changed

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,6 @@ Congratulations, your project is ready to go!
146146
- `QuerySet.delete()` and `update()` do not support queries that span multiple
147147
collections.
148148

149-
- `Subquery`, `Exists`, and using a `QuerySet` in `QuerySet.annotate()` aren't
150-
supported.
151-
152149
- `DateTimeField` doesn't support microsecond precision, and correspondingly,
153150
`DurationField` stores milliseconds rather than microseconds.
154151

django_mongodb/compiler.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from django.db.models.lookups import IsNull
1414
from django.db.models.sql import compiler
1515
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
16+
from django.db.models.sql.datastructures import BaseTable
1617
from django.utils.functional import cached_property
1718
from pymongo import ASCENDING, DESCENDING
1819

@@ -25,12 +26,16 @@ class SQLCompiler(compiler.SQLCompiler):
2526

2627
query_class = MongoQuery
2728
GROUP_SEPARATOR = "___"
29+
PARENT_FIELD_TEMPLATE = "parent__field__{}"
2830

2931
def __init__(self, *args, **kwargs):
3032
super().__init__(*args, **kwargs)
3133
self.aggregation_pipeline = None
34+
# Map columns to their subquery indices.
35+
self.column_indices = {}
3236
# A list of OrderBy objects for this query.
3337
self.order_by_objs = None
38+
self.subqueries = []
3439

3540
def _unfold_column(self, col):
3641
"""
@@ -154,23 +159,40 @@ def _prepare_annotations_for_aggregation_pipeline(self, order_by):
154159
group.update(having_group)
155160
return group, replacements
156161

157-
def _get_group_id_expressions(self, order_by):
158-
"""Generate group ID expressions for the aggregation pipeline."""
159-
group_expressions = set()
160-
replacements = {}
162+
def _get_group_expressions(self, order_by):
163+
if self.query.group_by is None:
164+
return []
165+
seen = set()
166+
expressions = set()
167+
if self.query.group_by is not True:
168+
# If group_by isn't True, then it's a list of expressions.
169+
for expr in self.query.group_by:
170+
if not hasattr(expr, "as_sql"):
171+
expr = self.query.resolve_ref(expr)
172+
if isinstance(expr, Ref):
173+
if expr.refs not in seen:
174+
seen.add(expr.refs)
175+
expressions.add(expr.source)
176+
else:
177+
expressions.add(expr)
178+
for expr, _, alias in self.select:
179+
# Skip members that are already grouped.
180+
if alias not in seen:
181+
expressions |= set(expr.get_group_by_cols())
161182
if not self._meta_ordering:
162183
for expr, (_, _, is_ref) in order_by:
184+
# Skip references.
163185
if not is_ref:
164-
group_expressions |= set(expr.get_group_by_cols())
165-
for expr, *_ in self.select:
166-
group_expressions |= set(expr.get_group_by_cols())
186+
expressions |= set(expr.get_group_by_cols())
167187
having_group_by = self.having.get_group_by_cols() if self.having else ()
168188
for expr in having_group_by:
169-
group_expressions.add(expr)
170-
if isinstance(self.query.group_by, tuple | list):
171-
group_expressions |= set(self.query.group_by)
172-
elif self.query.group_by is None:
173-
group_expressions = set()
189+
expressions.add(expr)
190+
return expressions
191+
192+
def _get_group_id_expressions(self, order_by):
193+
"""Generate group ID expressions for the aggregation pipeline."""
194+
replacements = {}
195+
group_expressions = self._get_group_expressions(order_by)
174196
if not group_expressions:
175197
ids = None
176198
else:
@@ -186,6 +208,8 @@ def _get_group_id_expressions(self, order_by):
186208
ids[alias] = Value(True).as_mql(self, self.connection)
187209
if replacement is not None:
188210
replacements[col] = replacement
211+
if isinstance(col, Ref):
212+
replacements[col.source] = replacement
189213
return ids, replacements
190214

191215
def _build_aggregation_pipeline(self, ids, group):
@@ -228,15 +252,15 @@ def pre_sql_setup(self, with_col_aliases=False):
228252
all_replacements.update(replacements)
229253
pipeline = self._build_aggregation_pipeline(ids, group)
230254
if self.having:
231-
pipeline.append(
232-
{
233-
"$match": {
234-
"$expr": self.having.replace_expressions(all_replacements).as_mql(
235-
self, self.connection
236-
)
237-
}
238-
}
255+
having = self.having.replace_expressions(all_replacements).as_mql(
256+
self, self.connection
239257
)
258+
# Add HAVING subqueries.
259+
for query in self.subqueries or ():
260+
pipeline.extend(query.get_pipeline())
261+
# Remove the added subqueries.
262+
self.subqueries = []
263+
pipeline.append({"$match": {"$expr": having}})
240264
self.aggregation_pipeline = pipeline
241265
self.annotations = {
242266
target: expr.replace_expressions(all_replacements)
@@ -388,6 +412,7 @@ def build_query(self, columns=None):
388412
query.mongo_query = {"$expr": expr}
389413
if extra_fields:
390414
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
415+
query.subqueries = self.subqueries
391416
return query
392417

393418
def get_columns(self):
@@ -431,7 +456,12 @@ def project_field(column):
431456

432457
@cached_property
433458
def collection_name(self):
434-
return self.query.get_meta().db_table
459+
base_table = next(
460+
v
461+
for k, v in self.query.alias_map.items()
462+
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
463+
)
464+
return base_table.table_alias or base_table.table_name
435465

436466
@cached_property
437467
def collection(self):
@@ -581,7 +611,7 @@ def _get_ordering(self):
581611
return tuple(fields), sort_ordering, tuple(extra_fields)
582612

583613
def get_where(self):
584-
return self.where
614+
return getattr(self, "where", self.query.where)
585615

586616
def explain_query(self):
587617
# Validate format (none supported) and options.
@@ -741,7 +771,7 @@ def build_query(self, columns=None):
741771
else None
742772
)
743773
subquery = compiler.build_query(columns)
744-
query.subquery = subquery
774+
query.subqueries = [subquery]
745775
return query
746776

747777
def _make_result(self, result, columns=None):

django_mongodb/expressions.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Case,
1010
Col,
1111
CombinedExpression,
12+
Exists,
1213
ExpressionWrapper,
1314
F,
1415
NegatedExpression,
@@ -50,6 +51,18 @@ def case(self, compiler, connection):
5051

5152

5253
def col(self, compiler, connection): # noqa: ARG001
54+
# If the column is part of a subquery and belongs to one of the parent
55+
# queries, it will be stored for reference using $let in a $lookup stage.
56+
if (
57+
self.alias not in compiler.query.alias_refcount
58+
or compiler.query.alias_refcount[self.alias] == 0
59+
):
60+
try:
61+
index = compiler.column_indices[self]
62+
except KeyError:
63+
index = len(compiler.column_indices)
64+
compiler.column_indices[self] = index
65+
return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
5366
# Add the column's collection's alias for columns in joined collections.
5467
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
5568
return f"${prefix}{self.target.column}"
@@ -79,8 +92,73 @@ def order_by(self, compiler, connection):
7992
return self.expression.as_mql(compiler, connection)
8093

8194

82-
def query(self, compiler, connection): # noqa: ARG001
83-
raise NotSupportedError("Using a QuerySet in annotate() is not supported on MongoDB.")
95+
def query(self, compiler, connection, lookup_name=None):
96+
subquery_compiler = self.get_compiler(connection=connection)
97+
subquery_compiler.pre_sql_setup(with_col_aliases=False)
98+
columns = subquery_compiler.get_columns()
99+
field_name, expr = columns[0]
100+
subquery = subquery_compiler.build_query(
101+
columns
102+
if subquery_compiler.query.annotations or not subquery_compiler.query.default_cols
103+
else None
104+
)
105+
table_output = f"__subquery{len(compiler.subqueries)}"
106+
from_table = next(
107+
e.table_name for alias, e in self.alias_map.items() if self.alias_refcount[alias]
108+
)
109+
# To perform a subquery, a $lookup stage that escapsulates the entire
110+
# subquery pipeline is added. The "let" clause defines the variables
111+
# needed to bridge the main collection with the subquery.
112+
subquery.subquery_lookup = {
113+
"as": table_output,
114+
"from": from_table,
115+
"let": {
116+
compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection)
117+
for col, i in subquery_compiler.column_indices.items()
118+
},
119+
}
120+
# The result must be a list of values. The output is compressed with an
121+
# aggregation pipeline.
122+
if lookup_name in ("in", "range"):
123+
if subquery.aggregation_pipeline is None:
124+
subquery.aggregation_pipeline = []
125+
subquery.aggregation_pipeline.extend(
126+
[
127+
{
128+
"$facet": {
129+
"group": [
130+
{
131+
"$group": {
132+
"_id": None,
133+
"tmp_name": {
134+
"$addToSet": expr.as_mql(subquery_compiler, connection)
135+
},
136+
}
137+
}
138+
]
139+
}
140+
},
141+
{
142+
"$project": {
143+
field_name: {
144+
"$ifNull": [
145+
{
146+
"$getField": {
147+
"input": {"$arrayElemAt": ["$group", 0]},
148+
"field": "tmp_name",
149+
}
150+
},
151+
[],
152+
]
153+
}
154+
}
155+
},
156+
]
157+
)
158+
# Erase project_fields since the required value is projected above.
159+
subquery.project_fields = None
160+
compiler.subqueries.append(subquery)
161+
return f"${table_output}.{field_name}"
84162

85163

86164
def raw_sql(self, compiler, connection): # noqa: ARG001
@@ -100,8 +178,16 @@ def star(self, compiler, connection): # noqa: ARG001
100178
return {"$literal": True}
101179

102180

103-
def subquery(self, compiler, connection): # noqa: ARG001
104-
raise NotSupportedError(f"{self.__class__.__name__} is not supported on MongoDB.")
181+
def subquery(self, compiler, connection, lookup_name=None):
182+
return self.query.as_mql(compiler, connection, lookup_name=lookup_name)
183+
184+
185+
def exists(self, compiler, connection, lookup_name=None):
186+
try:
187+
lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name)
188+
except EmptyResultSet:
189+
return Value(False).as_mql(compiler, connection)
190+
return connection.mongo_operators["isnull"](lhs_mql, False)
105191

106192

107193
def when(self, compiler, connection):
@@ -130,6 +216,7 @@ def register_expressions():
130216
Case.as_mql = case
131217
Col.as_mql = col
132218
CombinedExpression.as_mql = combined_expression
219+
Exists.as_mql = exists
133220
ExpressionWrapper.as_mql = expression_wrapper
134221
F.as_mql = f
135222
NegatedExpression.as_mql = negated_expression

0 commit comments

Comments
 (0)