Skip to content

Added support for nullable cursor keys #8431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
<PackageVersion Include="Squadron.RabbitMQ" Version="0.25.0-preview.2" />
<PackageVersion Include="Squadron.RavenDB" Version="0.25.0-preview.2" />
<PackageVersion Include="Squadron.Redis" Version="0.25.0-preview.2" />
<PackageVersion Include="Squadron.SqlServer" Version="0.25.0-preview.2" />
<PackageVersion Include="Squadron.AzureStorage" Version="0.25.0-preview.2" />
<PackageVersion Include="StackExchange.Redis" Version="2.6.80" />
<PackageVersion Include="System.Collections.Immutable" Version="8.0.0" />
Expand All @@ -78,6 +79,7 @@
<PackageVersion Include="Microsoft.EntityFrameworkCore.InMemory" Version="10.0.0-preview.5.25277.114" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="10.0.0-preview.5.25277.114" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="10.0.0-preview.5.25277.114" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.0-preview.5.25277.114" />
<PackageVersion Include="Microsoft.Extensions.Caching.Memory" Version="10.0.0-preview.5.25277.114" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection" Version="10.0.0-preview.5.25277.114" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-preview.5.25277.114" />
Expand All @@ -101,6 +103,7 @@
<PackageVersion Include="Microsoft.EntityFrameworkCore.InMemory" Version="9.0.4" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="9.0.4" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="9.0.4" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.Caching.Memory" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="9.0.4" />
Expand All @@ -124,6 +127,7 @@
<PackageVersion Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.15" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="8.0.15" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="8.0.15" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="8.0.15" />
<PackageVersion Include="Microsoft.Extensions.Caching.Memory" Version="8.0.1" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection" Version="8.0.1" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="8.0.2" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ internal static class ExpressionHelpers
.GetMethod(nameof(CreateAndConvertParameter), BindingFlags.NonPublic | BindingFlags.Static)!;

private static readonly ConcurrentDictionary<Type, Func<object?, Expression>> s_cachedConverters = new();
private static readonly NullabilityInfoContext s_nullabilityInfoContext = new();
private static readonly Expression s_null = Expression.Constant(null);
private static readonly Expression s_false = Expression.Constant(false);
private static readonly Expression s_zero = Expression.Constant(0);

/// <summary>
/// Builds a where expression that can be used to slice a dataset.
Expand All @@ -28,6 +32,9 @@ internal static class ExpressionHelpers
/// <param name="forward">
/// Defines how the dataset is sorted.
/// </param>
/// <param name="nullOrdering">
/// Defines the null ordering to be used.
/// </param>
/// <typeparam name="T">
/// The entity type.
/// </typeparam>
Expand All @@ -43,7 +50,8 @@ internal static class ExpressionHelpers
public static (Expression<Func<T, bool>> WhereExpression, int Offset) BuildWhereExpression<T>(
ReadOnlySpan<CursorKey> keys,
Cursor cursor,
bool forward)
bool forward,
NullOrdering nullOrdering)
{
if (keys.Length == 0)
{
Expand All @@ -58,14 +66,16 @@ public static (Expression<Func<T, bool>> WhereExpression, int Offset) BuildWhere
var cursorExpr = new Expression[cursor.Values.Length];
for (var i = 0; i < cursor.Values.Length; i++)
{
cursorExpr[i] = CreateParameter(cursor.Values[i], keys[i].Expression.ReturnType);
var parameterType = Nullable.GetUnderlyingType(keys[i].Expression.ReturnType)
?? keys[i].Expression.ReturnType;

cursorExpr[i] = CreateParameter(cursor.Values[i], parameterType);
}

var handled = new List<CursorKey>();
Expression? expression = null;

var parameter = Expression.Parameter(typeof(T), "t");
var zero = Expression.Constant(0);

for (var i = 0; i < keys.Length; i++)
{
Expand All @@ -77,24 +87,50 @@ public static (Expression<Func<T, bool>> WhereExpression, int Offset) BuildWhere
{
var handledKey = handled[j];

keyExpr = Expression.Equal(
Expression.Call(ReplaceParameter(handledKey.Expression, parameter), handledKey.CompareMethod,
cursorExpr[j]), zero);
keyExpr = BuildEqualToKeyExpr(
handledKey,
parameter,
cursor.Values[j],
cursorExpr[j]);

current = current is null ? keyExpr : Expression.AndAlso(current, keyExpr);
}

var keyIsNullable = IsNullable(key.Expression);

if (keyIsNullable && nullOrdering == NullOrdering.Unspecified)
{
throw new Exception(
"The NullOrdering option must be specified in the paging options or " +
"arguments when using nullable keys.");
}

var greaterThan = forward
? key.Direction == CursorKeyDirection.Ascending
: key.Direction == CursorKeyDirection.Descending;

keyExpr = greaterThan
? Expression.GreaterThan(
Expression.Call(ReplaceParameter(key.Expression, parameter), key.CompareMethod, cursorExpr[i]),
zero)
: Expression.LessThan(
Expression.Call(ReplaceParameter(key.Expression, parameter), key.CompareMethod, cursorExpr[i]),
zero);
if (greaterThan)
{
keyExpr =
BuildGreaterThanKeyExpr(
key,
parameter,
cursor.Values[i],
keyIsNullable,
nullOrdering,
cursorExpr[i]);
}
else
{
keyExpr =
BuildLessThanKeyExpr(
key,
parameter,
cursor.Values[i],
keyIsNullable,
nullOrdering,
cursorExpr[i]);
}

current = current is null ? keyExpr : Expression.AndAlso(current, keyExpr);
expression = expression is null ? current : Expression.OrElse(expression, current);
Expand All @@ -104,6 +140,207 @@ public static (Expression<Func<T, bool>> WhereExpression, int Offset) BuildWhere
return (Expression.Lambda<Func<T, bool>>(expression!, parameter), cursor.Offset ?? 0);
}

private static Expression BuildEqualToKeyExpr(
CursorKey cursorKey,
ParameterExpression parameter,
object? cursorValue,
Expression cursorExpr)
{
var keyIsNullable = IsNullable(cursorKey.Expression);
var keyExpr = ReplaceParameter(cursorKey.Expression, parameter);

// Access the value of the key if it is a nullable value type.
var keyValueExpr = cursorKey.Expression.ReturnType.IsValueType && keyIsNullable
? Expression.Property(keyExpr, "Value")
: keyExpr;

if (keyIsNullable)
{
if (cursorValue is null)
{
// SQL: WHERE key IS NULL.
keyExpr = Expression.Equal(keyExpr, s_null);
}
else
{
// SQL: WHERE key = cursorValue.
keyExpr = Expression.Equal(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}
}
else
{
// SQL: WHERE key = cursorValue.
keyExpr = Expression.Equal(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}

return keyExpr;
}

private static Expression BuildGreaterThanKeyExpr(
CursorKey cursorKey,
ParameterExpression parameter,
object? cursorValue,
bool keyIsNullable,
NullOrdering nullOrdering,
Expression cursorExpr)
{
var keyExpr = ReplaceParameter(cursorKey.Expression, parameter);

// Access the value of the key if it is a nullable value type.
var keyValueExpr =
cursorKey.Expression.ReturnType.IsValueType && keyIsNullable
? Expression.Property(keyExpr, "Value")
: keyExpr;

if (keyIsNullable)
{
if (cursorValue is null)
{
keyExpr = nullOrdering == NullOrdering.NativeNullsFirst
// With nulls first, any non-null value is greater than null.
// SQL: WHERE key IS NOT NULL.
? Expression.NotEqual(keyExpr, s_null)
// With nulls last, no value is greater than null.
// SQL: WHERE false.
: s_false;
}
else
{
if (nullOrdering == NullOrdering.NativeNullsFirst)
{
// SQL: WHERE key > cursorValue.
keyExpr = Expression.GreaterThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}
else
{
// When nulls are last, null is greater than any non-null value.
// SQL: WHERE key > cursorValue OR key IS NULL.
keyExpr = Expression.Or(
Expression.GreaterThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero),
Expression.Equal(keyExpr, s_null));
}
}
}
else
{
// SQL: WHERE key > cursorValue.
keyExpr = Expression.GreaterThan(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}

return keyExpr;
}

private static Expression BuildLessThanKeyExpr(
CursorKey cursorKey,
ParameterExpression parameter,
object? cursorValue,
bool keyIsNullable,
NullOrdering nullOrdering,
Expression cursorExpr)
{
var keyExpr = ReplaceParameter(cursorKey.Expression, parameter);

// Access the value of the key if it is a nullable value type.
var keyValueExpr =
cursorKey.Expression.ReturnType.IsValueType && keyIsNullable
? Expression.Property(keyExpr, "Value")
: keyExpr;

if (keyIsNullable)
{
if (cursorValue is null)
{
keyExpr = nullOrdering == NullOrdering.NativeNullsFirst
// With nulls first, no value is less than null.
// SQL: WHERE false.
? s_false
// With nulls last, any non-null value is less than null.
// SQL: WHERE key IS NOT NULL.
: Expression.NotEqual(keyExpr, s_null);
}
else
{
if (nullOrdering == NullOrdering.NativeNullsFirst)
{
// With nulls first, null is less than any non-null value.
// SQL: WHERE key < cursorValue OR key IS NULL.
keyExpr = Expression.Or(
Expression.LessThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero),
Expression.Equal(keyExpr, s_null));
}
else
{
// SQL: WHERE key < cursorValue.
keyExpr = Expression.LessThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}
}
}
else
{
// SQL: WHERE key < cursorValue.
keyExpr = Expression.LessThan(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}

return keyExpr;
}

private static bool IsNullable(LambdaExpression expression)
{
if (expression.ReturnType.IsValueType)
{
return Nullable.GetUnderlyingType(expression.ReturnType) is not null;
}

var propertyInfo = expression.Body switch
{
MemberExpression
{
Member: PropertyInfo p
} => p,
BinaryExpression
{
NodeType: ExpressionType.Coalesce,
Right: MemberExpression { Member: PropertyInfo p }
} => p,
_ => null
};

if (propertyInfo is not null)
{
var nullability = s_nullabilityInfoContext.Create(propertyInfo).ReadState;

switch (nullability)
{
case NullabilityState.Nullable:
return true;
case NullabilityState.NotNull:
return false;
case NullabilityState.Unknown:
break;
default:
throw new InvalidOperationException();
}
}

throw new Exception("The nullability of the cursor key could not be determined.");
}

/// <summary>
/// Build the select expression for a batch paging expression that uses grouping.
/// </summary>
Expand Down Expand Up @@ -184,7 +421,11 @@ public static BatchExpression<TK, TV> BuildBatchExpression<TK, TV>(
if (arguments.After is not null)
{
cursor = CursorParser.Parse(arguments.After, keys);
var (whereExpr, cursorOffset) = BuildWhereExpression<TV>(keys, cursor, forward: true);
var (whereExpr, cursorOffset) = BuildWhereExpression<TV>(
keys,
cursor,
forward: true,
arguments.NullOrdering);
source = Expression.Call(typeof(Enumerable), "Where", [typeof(TV)], source, whereExpr);
offset = cursorOffset;

Expand All @@ -204,7 +445,12 @@ public static BatchExpression<TK, TV> BuildBatchExpression<TK, TV>(
}

cursor = CursorParser.Parse(arguments.Before, keys);
var (whereExpr, cursorOffset) = BuildWhereExpression<TV>(keys, cursor, forward: false);
var (whereExpr, cursorOffset) = BuildWhereExpression<TV>(
keys,
cursor,
forward:
false,
arguments.NullOrdering);
source = Expression.Call(typeof(Enumerable), "Where", [typeof(TV)], source, whereExpr);
offset = cursorOffset;
}
Expand Down
Loading
Loading