Skip to content

Commit ae2ff8f

Browse files
committed
Fix AOT vs reflective behavior failures in AotUserRepositoryTests.
Closes #3951
1 parent 529eb72 commit ae2ff8f

File tree

12 files changed

+141
-129
lines changed

12 files changed

+141
-129
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/AotQueries.java

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,7 @@ record AotQueries(AotQuery result, AotQuery count) {
3939
/**
4040
* Derive a count query from the given query.
4141
*/
42-
public static AotQueries from(StringAotQuery query, @Nullable String countProjection,
43-
QueryEnhancerSelector selector) {
44-
return from(query, StringAotQuery::getQuery, countProjection, selector);
45-
}
46-
47-
/**
48-
* Derive a count query from the given query.
49-
*/
50-
public static <T extends AotQuery> AotQueries from(T query, Function<T, DeclaredQuery> queryMapper,
42+
public static <T extends AotQuery> AotQueries withDerivedCountQuery(T query, Function<T, DeclaredQuery> queryMapper,
5143
@Nullable String countProjection, QueryEnhancerSelector selector) {
5244

5345
DeclaredQuery underlyingQuery = queryMapper.apply(query);
@@ -56,8 +48,7 @@ public static <T extends AotQuery> AotQueries from(T query, Function<T, Declared
5648
String derivedCountQuery = queryEnhancer
5749
.createCountQueryFor(StringUtils.hasText(countProjection) ? countProjection : null);
5850

59-
DeclaredQuery countQuery = underlyingQuery.rewrite(derivedCountQuery);
60-
return new AotQueries(query, StringAotQuery.of(countQuery));
51+
return new AotQueries(query, StringAotQuery.of(underlyingQuery.rewrite(derivedCountQuery)));
6152
}
6253

6354
/**

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/AotQuery.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919

2020
import org.springframework.data.domain.Limit;
21+
import org.springframework.data.jpa.repository.query.EntityQuery;
2122
import org.springframework.data.jpa.repository.query.ParameterBinding;
2223

2324
/**
@@ -35,6 +36,10 @@ abstract class AotQuery {
3536
this.parameterBindings = parameterBindings;
3637
}
3738

39+
static boolean hasConstructorExpressionOrDefaultProjection(EntityQuery query) {
40+
return query.hasConstructorExpression() || query.isDefaultProjection();
41+
}
42+
3843
/**
3944
* @return whether the query is a {@link jakarta.persistence.EntityManager#createNativeQuery native} one.
4045
*/
@@ -89,4 +94,10 @@ public boolean hasExpression() {
8994
return false;
9095
}
9196

97+
/**
98+
* @return {@literal true} if query is expected to return the declared method type directly; {@literal false} if the
99+
* result requires projection post-processing. See also {@code NativeJpaQuery#getTypeToQueryFor}.
100+
*/
101+
public abstract boolean hasConstructorExpressionOrDefaultProjection();
102+
92103
}

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/JpaCodeBlocks.java

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ static class QueryBlockBuilder {
9696
private @Nullable Class<?> queryRewriter = QueryRewriter.IdentityQueryRewriter.class;
9797

9898
private QueryBlockBuilder(AotQueryMethodGenerationContext context, JpaQueryMethod queryMethod) {
99+
99100
this.context = context;
100101
this.queryMethod = queryMethod;
101102
this.queryVariableName = context.localVariable("query");
@@ -294,14 +295,16 @@ private CodeBlock applyLimits(boolean exists, String pageable) {
294295
return builder.build();
295296
}
296297

298+
if (queries != null && queries.result() instanceof StringAotQuery sq && sq.hasPagingExpression()) {
299+
return builder.build();
300+
}
301+
297302
String limit = context.getLimitParameterName();
298303

299304
if (StringUtils.hasText(limit)) {
300305
builder.beginControlFlow("if ($L.isLimited())", limit);
301306
builder.addStatement("$L.setMaxResults($L.max())", queryVariableName, limit);
302307
builder.endControlFlow();
303-
} else if (queries != null && queries.result().isLimited()) {
304-
builder.addStatement("$L.setMaxResults($L)", queryVariableName, queries.result().getLimit().max());
305308
}
306309

307310
if (StringUtils.hasText(pageable)) {
@@ -316,6 +319,20 @@ private CodeBlock applyLimits(boolean exists, String pageable) {
316319
builder.endControlFlow();
317320
}
318321

322+
if (queries.result().isLimited()) {
323+
324+
int max = queries.result().getLimit().max();
325+
326+
builder.beginControlFlow("if ($L.getMaxResults() != $T.MAX_VALUE)", queryVariableName, Integer.class);
327+
builder.beginControlFlow("if ($1L.getMaxResults() > $2L && $1L.getFirstResult() > 0)", queryVariableName, max);
328+
builder.addStatement("$1L.setFirstResult($1L.getFirstResult() - ($1L.getMaxResults() - $2L))",
329+
queryVariableName, max);
330+
builder.endControlFlow();
331+
builder.endControlFlow();
332+
333+
builder.addStatement("$1L.setMaxResults($2L)", queryVariableName, max);
334+
}
335+
319336
return builder.build();
320337
}
321338

@@ -484,11 +501,12 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
484501

485502
if (query instanceof NamedAotQuery nq) {
486503

487-
if (!count && returnedType.isProjecting() && returnedType.getReturnedType().isInterface()) {
488-
builder.addStatement("$T $L = this.$L.createNamedQuery($S)", Query.class, queryVariableName,
489-
context.fieldNameOf(EntityManager.class), nq.getName());
490-
return builder.build();
491-
} else if (queryReturnType != null) {
504+
if (!count && !nq.hasConstructorExpressionOrDefaultProjection() && returnedType.isProjecting()
505+
&& returnedType.getReturnedType().isInterface()) {
506+
queryReturnType = Tuple.class;
507+
}
508+
509+
if (queryReturnType != null) {
492510

493511
builder.addStatement("$T $L = this.$L.createNamedQuery($S, $T.class)", Query.class, queryVariableName,
494512
context.fieldNameOf(EntityManager.class), nq.getName(), queryReturnType);

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/JpaRepositoryContributor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ private Optional<Class<QueryEnhancerSelector>> getQueryEnhancerSelectorClass() {
183183

184184
MergedAnnotation<Query> query = MergedAnnotations.from(method).get(Query.class);
185185

186-
AotQueries aotQueries = queriesFactory.createQueries(getRepositoryInformation(), query, selector, queryMethod,
187-
returnedType);
186+
AotQueries aotQueries = queriesFactory.createQueries(getRepositoryInformation(), returnedType, selector, query,
187+
queryMethod);
188188

189189
// no KeysetScrolling for now.
190190
if (parameters.hasScrollPositionParameter()) {

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/NamedAotQuery.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
*/
1616
package org.springframework.data.jpa.repository.aot;
1717

18-
import java.util.List;
19-
2018
import org.springframework.data.jpa.repository.query.DeclaredQuery;
21-
import org.springframework.data.jpa.repository.query.ParameterBinding;
22-
import org.springframework.data.jpa.repository.query.PreprocessedQuery;
19+
import org.springframework.data.jpa.repository.query.EntityQuery;
2320

2421
/**
2522
* Value object to describe a named AOT query.
@@ -31,20 +28,20 @@ class NamedAotQuery extends AotQuery {
3128

3229
private final String name;
3330
private final DeclaredQuery query;
31+
private final boolean constructorExpressionOrDefaultProjection;
3432

35-
private NamedAotQuery(String name, DeclaredQuery queryString, List<ParameterBinding> parameterBindings) {
36-
super(parameterBindings);
33+
public NamedAotQuery(String name, EntityQuery entityQuery) {
34+
super(entityQuery.getParameterBindings());
3735
this.name = name;
38-
this.query = queryString;
36+
this.query = entityQuery.getQuery();
37+
this.constructorExpressionOrDefaultProjection = AotQuery.hasConstructorExpressionOrDefaultProjection(entityQuery);
3938
}
4039

4140
/**
4241
* Creates a new {@code NamedAotQuery}.
4342
*/
44-
public static NamedAotQuery named(String namedQuery, DeclaredQuery queryString) {
45-
46-
PreprocessedQuery parsed = PreprocessedQuery.parse(queryString);
47-
return new NamedAotQuery(namedQuery, queryString, parsed.getBindings());
43+
public static NamedAotQuery named(String namedQuery, EntityQuery query) {
44+
return new NamedAotQuery(namedQuery, query);
4845
}
4946

5047
public String getName() {
@@ -64,4 +61,9 @@ public boolean isNative() {
6461
return query.isNative();
6562
}
6663

64+
@Override
65+
public boolean hasConstructorExpressionOrDefaultProjection() {
66+
return constructorExpressionOrDefaultProjection;
67+
}
68+
6769
}

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/QueriesFactory.java

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -112,28 +112,28 @@ private NamedQueries getNamedQueries(@Nullable RepositoryConfigurationSource con
112112
* Creates the {@link AotQueries} used within a specific {@link JpaQueryMethod}.
113113
*
114114
* @param repositoryInformation
115-
* @param query
115+
* @param returnedType
116116
* @param selector
117+
* @param query
117118
* @param queryMethod
118-
* @param returnedType
119119
* @return
120120
*/
121-
public AotQueries createQueries(RepositoryInformation repositoryInformation, MergedAnnotation<Query> query,
122-
QueryEnhancerSelector selector, JpaQueryMethod queryMethod, ReturnedType returnedType) {
121+
public AotQueries createQueries(RepositoryInformation repositoryInformation, ReturnedType returnedType,
122+
QueryEnhancerSelector selector, MergedAnnotation<Query> query, JpaQueryMethod queryMethod) {
123123

124124
if (query.isPresent() && StringUtils.hasText(query.getString("value"))) {
125125
return buildStringQuery(repositoryInformation.getDomainType(), returnedType, selector, query, queryMethod);
126126
}
127127

128128
String queryName = queryMethod.getNamedQueryName();
129-
if (hasNamedQuery(queryName, returnedType)) {
129+
if (hasNamedQuery(returnedType, queryName)) {
130130
return buildNamedQuery(returnedType, selector, queryName, query, queryMethod);
131131
}
132132

133-
return buildPartTreeQuery(returnedType, repositoryInformation, query, queryMethod);
133+
return buildPartTreeQuery(repositoryInformation, returnedType, selector, query, queryMethod);
134134
}
135135

136-
private boolean hasNamedQuery(String queryName, ReturnedType returnedType) {
136+
private boolean hasNamedQuery(ReturnedType returnedType, String queryName) {
137137
return namedQueries.hasQuery(queryName) || getNamedQuery(returnedType, queryName) != null;
138138
}
139139

@@ -142,19 +142,15 @@ private AotQueries buildStringQuery(Class<?> domainType, ReturnedType returnedTy
142142

143143
UnaryOperator<String> operator = s -> s.replaceAll("#\\{#entityName}", domainType.getName());
144144
boolean isNative = query.getBoolean("nativeQuery");
145-
Function<String, StringAotQuery> queryFunction = isNative ? StringAotQuery::nativeQuery : StringAotQuery::jpqlQuery;
145+
Function<String, DeclaredQuery> queryFunction = isNative ? DeclaredQuery::nativeQuery : DeclaredQuery::jpqlQuery;
146146
queryFunction = operator.andThen(queryFunction);
147147

148148
String queryString = query.getString("value");
149149

150-
StringAotQuery aotStringQuery = queryFunction.apply(queryString);
150+
EntityQuery entityQuery = EntityQuery.create(queryFunction.apply(queryString), selector);
151+
StringAotQuery aotStringQuery = StringAotQuery.of(entityQuery);
151152
String countQuery = query.getString("countQuery");
152153

153-
EntityQuery entityQuery = EntityQuery.create(aotStringQuery.getQuery(), selector);
154-
if (entityQuery.hasConstructorExpression() || entityQuery.isDefaultProjection()) {
155-
aotStringQuery = aotStringQuery.withConstructorExpressionOrDefaultProjection();
156-
}
157-
158154
if (returnedType.isProjecting() && returnedType.hasInputProperties()
159155
&& !returnedType.getReturnedType().isInterface()) {
160156

@@ -174,38 +170,38 @@ public ReturnedType getReturnedType() {
174170
}
175171

176172
if (StringUtils.hasText(countQuery)) {
177-
return AotQueries.from(aotStringQuery, queryFunction.apply(countQuery));
173+
return AotQueries.from(aotStringQuery, StringAotQuery.of(queryFunction.apply(countQuery)));
178174
}
179175

180-
if (hasNamedQuery(queryMethod.getNamedCountQueryName(), returnedType)) {
176+
if (hasNamedQuery(returnedType, queryMethod.getNamedCountQueryName())) {
181177
return AotQueries.from(aotStringQuery,
182-
createNamedAotQuery(returnedType, queryMethod.getNamedCountQueryName(), queryMethod, isNative));
178+
createNamedAotQuery(returnedType, selector, queryMethod.getNamedCountQueryName(), queryMethod, isNative));
183179
}
184180

185181
String countProjection = query.getString("countProjection");
186-
return AotQueries.from(aotStringQuery, countProjection, selector);
182+
return AotQueries.withDerivedCountQuery(aotStringQuery, StringAotQuery::getQuery, countProjection, selector);
187183
}
188184

189-
private AotQueries buildNamedQuery(ReturnedType returnedType, QueryEnhancerSelector selector,
190-
String queryName, MergedAnnotation<Query> query, JpaQueryMethod queryMethod) {
185+
private AotQueries buildNamedQuery(ReturnedType returnedType, QueryEnhancerSelector selector, String queryName,
186+
MergedAnnotation<Query> query, JpaQueryMethod queryMethod) {
191187

192188
boolean nativeQuery = query.isPresent() && query.getBoolean("nativeQuery");
193-
AotQuery aotQuery = createNamedAotQuery(returnedType, queryName, queryMethod, nativeQuery);
194-
189+
AotQuery aotQuery = createNamedAotQuery(returnedType, selector, queryName, queryMethod, nativeQuery);
195190
String countQuery = query.isPresent() ? query.getString("countQuery") : null;
196191

197192
if (StringUtils.hasText(countQuery)) {
198193
return AotQueries.from(aotQuery,
199-
aotQuery.isNative() ? StringAotQuery.nativeQuery(countQuery) : StringAotQuery.jpqlQuery(countQuery));
194+
StringAotQuery
195+
.of(aotQuery.isNative() ? DeclaredQuery.nativeQuery(countQuery) : DeclaredQuery.jpqlQuery(countQuery)));
200196
}
201197

202-
if (hasNamedQuery(queryMethod.getNamedCountQueryName(), returnedType)) {
198+
if (hasNamedQuery(returnedType, queryMethod.getNamedCountQueryName())) {
203199
return AotQueries.from(aotQuery,
204-
createNamedAotQuery(returnedType, queryMethod.getNamedCountQueryName(), queryMethod, nativeQuery));
200+
createNamedAotQuery(returnedType, selector, queryMethod.getNamedCountQueryName(), queryMethod, nativeQuery));
205201
}
206202

207203
String countProjection = query.isPresent() ? query.getString("countProjection") : null;
208-
return AotQueries.from(aotQuery, it -> {
204+
return AotQueries.withDerivedCountQuery(aotQuery, it -> {
209205

210206
if (it instanceof StringAotQuery sq) {
211207
return sq.getQuery();
@@ -215,25 +211,26 @@ private AotQueries buildNamedQuery(ReturnedType returnedType, QueryEnhancerSelec
215211
}, countProjection, selector);
216212
}
217213

218-
private AotQuery createNamedAotQuery(ReturnedType returnedType, String queryName, JpaQueryMethod queryMethod,
219-
boolean isNative) {
214+
private AotQuery createNamedAotQuery(ReturnedType returnedType, QueryEnhancerSelector selector, String queryName,
215+
JpaQueryMethod queryMethod, boolean isNative) {
220216

221217
if (namedQueries.hasQuery(queryName)) {
222218

223219
String queryString = namedQueries.getQuery(queryName);
224-
return StringAotQuery.named(queryName,
225-
isNative ? DeclaredQuery.nativeQuery(queryString) : DeclaredQuery.jpqlQuery(queryString));
220+
221+
DeclaredQuery query = isNative ? DeclaredQuery.nativeQuery(queryString) : DeclaredQuery.jpqlQuery(queryString);
222+
return StringAotQuery.named(queryName, EntityQuery.create(query, selector));
226223
}
227224

228225
TypedQueryReference<?> namedQuery = getNamedQuery(returnedType, queryName);
229226

230227
Assert.state(namedQuery != null, "Native named query must not be null");
231228

232-
return createNamedAotQuery(namedQuery, queryMethod, isNative);
229+
return createNamedAotQuery(namedQuery, selector, isNative, queryMethod);
233230
}
234231

235-
private AotQuery createNamedAotQuery(TypedQueryReference<?> namedQuery, JpaQueryMethod queryMethod,
236-
boolean isNative) {
232+
private AotQuery createNamedAotQuery(TypedQueryReference<?> namedQuery, QueryEnhancerSelector selector,
233+
boolean isNative, JpaQueryMethod queryMethod) {
237234

238235
QueryExtractor queryExtractor = queryMethod.getQueryExtractor();
239236
String queryString = queryExtractor.extractQueryString(namedQuery);
@@ -244,8 +241,9 @@ private AotQuery createNamedAotQuery(TypedQueryReference<?> namedQuery, JpaQuery
244241

245242
Assert.hasText(queryString, () -> "Cannot extract Query from named query [%s]".formatted(namedQuery.getName()));
246243

247-
return NamedAotQuery.named(namedQuery.getName(),
248-
isNative ? DeclaredQuery.nativeQuery(queryString) : DeclaredQuery.jpqlQuery(queryString));
244+
DeclaredQuery query = isNative ? DeclaredQuery.nativeQuery(queryString) : DeclaredQuery.jpqlQuery(queryString);
245+
246+
return NamedAotQuery.named(namedQuery.getName(), EntityQuery.create(query, selector));
249247
}
250248

251249
private @Nullable TypedQueryReference<?> getNamedQuery(ReturnedType returnedType, String queryName) {
@@ -266,19 +264,20 @@ private AotQuery createNamedAotQuery(TypedQueryReference<?> namedQuery, JpaQuery
266264
return null;
267265
}
268266

269-
private AotQueries buildPartTreeQuery(ReturnedType returnedType, RepositoryInformation repositoryInformation,
267+
private AotQueries buildPartTreeQuery(RepositoryInformation repositoryInformation, ReturnedType returnedType,
268+
QueryEnhancerSelector selector,
270269
MergedAnnotation<Query> query, JpaQueryMethod queryMethod) {
271270

272271
PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType());
273272
AotQuery aotQuery = createQuery(partTree, returnedType, queryMethod.getParameters(), templates);
274273

275274
if (query.isPresent() && StringUtils.hasText(query.getString("countQuery"))) {
276-
return AotQueries.from(aotQuery, StringAotQuery.jpqlQuery(query.getString("countQuery")));
275+
return AotQueries.from(aotQuery, StringAotQuery.of(DeclaredQuery.jpqlQuery(query.getString("countQuery"))));
277276
}
278277

279-
if (hasNamedQuery(queryMethod.getNamedCountQueryName(), returnedType)) {
278+
if (hasNamedQuery(returnedType, queryMethod.getNamedCountQueryName())) {
280279
return AotQueries.from(aotQuery,
281-
createNamedAotQuery(returnedType, queryMethod.getNamedCountQueryName(), queryMethod, false));
280+
createNamedAotQuery(returnedType, selector, queryMethod.getNamedCountQueryName(), queryMethod, false));
282281
}
283282

284283
AotQuery partTreeCountQuery = createCountQuery(partTree, returnedType, queryMethod.getParameters(), templates);
@@ -318,19 +317,21 @@ private AotQuery createCountQuery(PartTree partTree, ReturnedType returnedType,
318317

319318
Class<?> result = queryForEntity ? returnedType.getDomainType() : null;
320319

321-
if (query instanceof StringAotQuery sq && sq.hasConstructorExpressionOrDefaultProjection()) {
322-
return result;
323-
}
324-
325320
if (returnedType.isProjecting()) {
326321

327322
if (returnedType.getReturnedType().isInterface()) {
323+
324+
if (query.hasConstructorExpressionOrDefaultProjection()) {
325+
return result;
326+
}
327+
328328
return Tuple.class;
329329
}
330330

331331
return returnedType.getReturnedType();
332332
}
333333

334+
334335
return result;
335336
}
336337

0 commit comments

Comments
 (0)