13
13
from django .db .models .lookups import IsNull
14
14
from django .db .models .sql import compiler
15
15
from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , SINGLE
16
+ from django .db .models .sql .datastructures import BaseTable
16
17
from django .utils .functional import cached_property
17
18
from pymongo import ASCENDING , DESCENDING
18
19
@@ -25,12 +26,16 @@ class SQLCompiler(compiler.SQLCompiler):
25
26
26
27
query_class = MongoQuery
27
28
GROUP_SEPARATOR = "___"
29
+ PARENT_FIELD_TEMPLATE = "parent__field__{}"
28
30
29
31
def __init__ (self , * args , ** kwargs ):
30
32
super ().__init__ (* args , ** kwargs )
31
33
self .aggregation_pipeline = None
34
+ # Map columns to their subquery indices.
35
+ self .column_indices = {}
32
36
# A list of OrderBy objects for this query.
33
37
self .order_by_objs = None
38
+ self .subqueries = []
34
39
35
40
def _unfold_column (self , col ):
36
41
"""
@@ -154,23 +159,40 @@ def _prepare_annotations_for_aggregation_pipeline(self, order_by):
154
159
group .update (having_group )
155
160
return group , replacements
156
161
157
- def _get_group_id_expressions (self , order_by ):
158
- """Generate group ID expressions for the aggregation pipeline."""
159
- group_expressions = set ()
160
- replacements = {}
162
+ def _get_group_expressions (self , order_by ):
163
+ if self .query .group_by is None :
164
+ return []
165
+ seen = set ()
166
+ expressions = set ()
167
+ if self .query .group_by is not True :
168
+ # If group_by isn't True, then it's a list of expressions.
169
+ for expr in self .query .group_by :
170
+ if not hasattr (expr , "as_sql" ):
171
+ expr = self .query .resolve_ref (expr )
172
+ if isinstance (expr , Ref ):
173
+ if expr .refs not in seen :
174
+ seen .add (expr .refs )
175
+ expressions .add (expr .source )
176
+ else :
177
+ expressions .add (expr )
178
+ for expr , _ , alias in self .select :
179
+ # Skip members that are already grouped.
180
+ if alias not in seen :
181
+ expressions |= set (expr .get_group_by_cols ())
161
182
if not self ._meta_ordering :
162
183
for expr , (_ , _ , is_ref ) in order_by :
184
+ # Skip references.
163
185
if not is_ref :
164
- group_expressions |= set (expr .get_group_by_cols ())
165
- for expr , * _ in self .select :
166
- group_expressions |= set (expr .get_group_by_cols ())
186
+ expressions |= set (expr .get_group_by_cols ())
167
187
having_group_by = self .having .get_group_by_cols () if self .having else ()
168
188
for expr in having_group_by :
169
- group_expressions .add (expr )
170
- if isinstance (self .query .group_by , tuple | list ):
171
- group_expressions |= set (self .query .group_by )
172
- elif self .query .group_by is None :
173
- group_expressions = set ()
189
+ expressions .add (expr )
190
+ return expressions
191
+
192
+ def _get_group_id_expressions (self , order_by ):
193
+ """Generate group ID expressions for the aggregation pipeline."""
194
+ replacements = {}
195
+ group_expressions = self ._get_group_expressions (order_by )
174
196
if not group_expressions :
175
197
ids = None
176
198
else :
@@ -186,6 +208,8 @@ def _get_group_id_expressions(self, order_by):
186
208
ids [alias ] = Value (True ).as_mql (self , self .connection )
187
209
if replacement is not None :
188
210
replacements [col ] = replacement
211
+ if isinstance (col , Ref ):
212
+ replacements [col .source ] = replacement
189
213
return ids , replacements
190
214
191
215
def _build_aggregation_pipeline (self , ids , group ):
@@ -228,15 +252,15 @@ def pre_sql_setup(self, with_col_aliases=False):
228
252
all_replacements .update (replacements )
229
253
pipeline = self ._build_aggregation_pipeline (ids , group )
230
254
if self .having :
231
- pipeline .append (
232
- {
233
- "$match" : {
234
- "$expr" : self .having .replace_expressions (all_replacements ).as_mql (
235
- self , self .connection
236
- )
237
- }
238
- }
255
+ having = self .having .replace_expressions (all_replacements ).as_mql (
256
+ self , self .connection
239
257
)
258
+ # Add HAVING subqueries.
259
+ for query in self .subqueries or ():
260
+ pipeline .extend (query .get_pipeline ())
261
+ # Remove the added subqueries.
262
+ self .subqueries = []
263
+ pipeline .append ({"$match" : {"$expr" : having }})
240
264
self .aggregation_pipeline = pipeline
241
265
self .annotations = {
242
266
target : expr .replace_expressions (all_replacements )
@@ -388,6 +412,7 @@ def build_query(self, columns=None):
388
412
query .mongo_query = {"$expr" : expr }
389
413
if extra_fields :
390
414
query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
415
+ query .subqueries = self .subqueries
391
416
return query
392
417
393
418
def get_columns (self ):
@@ -431,7 +456,12 @@ def project_field(column):
431
456
432
457
@cached_property
433
458
def collection_name (self ):
434
- return self .query .get_meta ().db_table
459
+ base_table = next (
460
+ v
461
+ for k , v in self .query .alias_map .items ()
462
+ if isinstance (v , BaseTable ) and self .query .alias_refcount [k ]
463
+ )
464
+ return base_table .table_alias or base_table .table_name
435
465
436
466
@cached_property
437
467
def collection (self ):
@@ -581,7 +611,7 @@ def _get_ordering(self):
581
611
return tuple (fields ), sort_ordering , tuple (extra_fields )
582
612
583
613
def get_where (self ):
584
- return self . where
614
+ return getattr ( self , "where" , self . query . where )
585
615
586
616
def explain_query (self ):
587
617
# Validate format (none supported) and options.
@@ -741,7 +771,7 @@ def build_query(self, columns=None):
741
771
else None
742
772
)
743
773
subquery = compiler .build_query (columns )
744
- query .subquery = subquery
774
+ query .subqueries = [ subquery ]
745
775
return query
746
776
747
777
def _make_result (self , result , columns = None ):
0 commit comments