diff --git a/src/Caliburn.Micro.Platform.Tests/ActionMessageTests.cs b/src/Caliburn.Micro.Platform.Tests/ActionMessageTests.cs new file mode 100644 index 000000000..bed5f84eb --- /dev/null +++ b/src/Caliburn.Micro.Platform.Tests/ActionMessageTests.cs @@ -0,0 +1,381 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using System.Windows.Forms; +using Xunit; + +namespace Caliburn.Micro.Platform.Tests +{ + public class ActionMessageTests + { + [Fact] + public void GetMethodPicksOverLoadStructParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = 1 }); + am.Parameters.Add(new Parameter() { Value = 1 }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(1, 1)); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadStructStringParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = 1 }); + am.Parameters.Add(new Parameter() { Value = 1 }); + am.Parameters.Add(new Parameter() { Value = "" }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(1, 1, "")); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadStringParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = "" }); + am.Parameters.Add(new Parameter() { Value = "" }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded("", "")); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadStringStructParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = "" }); + am.Parameters.Add(new Parameter() { Value = "" }); + am.Parameters.Add(new Parameter() { Value = 1 }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded("", "", 1)); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadTwoDifferntParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = new Foo() }); + am.Parameters.Add(new Parameter() { Value = new Bar() }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(new Foo(), new Bar())); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadTwoBaseParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = new Foo() }); + am.Parameters.Add(new Parameter() { Value = new Foo() }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(new Foo(), new Foo())); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadTwoDerivedParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = new Bar() }); + am.Parameters.Add(new Parameter() { Value = new Bar() }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(new Bar(), new Bar())); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadEnumParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = OverloadEnum.One }); + var obj = new Overloads(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(OverloadEnum.One)); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadDerivedOnlyParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = new Bar() }); + am.Parameters.Add(new Parameter() { Value = new Bar() }); + var obj = new OverloadsDerivedOnly(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(new Bar(), new Bar())); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadDerivedOnlyBaseParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = new Foo() }); + am.Parameters.Add(new Parameter() { Value = new Bar() }); + var obj = new OverloadsDerivedOnly(); + + MethodInfo expected = null; + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + + [Fact] + public void GetMethodPicksOverLoadBaseOnlyParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = new Bar() }); + am.Parameters.Add(new Parameter() { Value = new Bar() }); + var obj = new OverloadsBaseOnly(); + + var expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(new Foo(), new Foo())); + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + [Fact] + public void GetMethodPicksOverLoadBaseOnlyBaseParameter() + { + //Arrange + var am = new ActionMessage(); + am.MethodName = "Overloaded"; + + am.Parameters.Add(new Parameter() { Value = new Foo() }); + am.Parameters.Add(new Parameter() { Value = new Bar() }); + var obj = new OverloadsBaseOnly(); + + MethodInfo expected = MethodInfoHelper.GetMethodInfo(o => o.Overloaded(new Foo(), new Foo())); + ; + + //Act + var result = ActionMessage.GetTargetMethod(am, obj); + + //Assert + Assert.Equal(result, expected); + } + + enum OverloadEnum + { + One,Two,Three + } + + class OverloadsBaseOnly + { + public void Overloaded(Foo f) + { + + } + + public void Overloaded(Foo f, Foo f2) + { + + } + } + + class OverloadsDerivedOnly + { + public void Overloaded(Bar b) + { + + } + + public void Overloaded(Bar b, Bar b2) + { + + } + } + + class Overloads + { + public void Overloaded(int i) + { + + } + + public void Overloaded(int i, int i2) + { + + } + + public void Overloaded(int i, int i2, string s) + { + + } + + public void Overloaded(string s) + { + + } + + public void Overloaded(string s, string s2) + { + + } + + public void Overloaded(string s, string s2, int i) + { + + } + + public void Overloaded(Foo F) + { + + } + + public void Overloaded(Bar B) + { + + } + + public void Overloaded(OverloadEnum E) + { + + } + + public void Overloaded(Foo s, Foo f) + { + + } + + public void Overloaded(Foo s, Bar b) + { + + } + + public void Overloaded(Bar s, Bar b) + { + + } + + public void Overloaded(Bar s, Foo f) + { + + } + } + + class Foo + { + + } + + class Bar : Foo + { + + } + } + + public static class MethodInfoHelper + { + public static MethodInfo GetMethodInfo(Expression> expression) + { + var member = expression.Body as MethodCallExpression; + + if (member != null) + return member.Method; + + throw new ArgumentException("Expression is not a method", "expression"); + } + + } +} diff --git a/src/Caliburn.Micro.Platform/ActionMessage.cs b/src/Caliburn.Micro.Platform/ActionMessage.cs index 8d7e043ef..18963e089 100644 --- a/src/Caliburn.Micro.Platform/ActionMessage.cs +++ b/src/Caliburn.Micro.Platform/ActionMessage.cs @@ -373,19 +373,7 @@ public override string ToString() { /// /// The matching method, if available. public static Func GetTargetMethod = (message, target) => { -#if WINDOWS_UWP - return (from method in target.GetType().GetRuntimeMethods() - where method.Name == message.MethodName - let methodParameters = method.GetParameters() - where message.Parameters.Count == methodParameters.Length - select method).FirstOrDefault(); -#else - return (from method in target.GetType().GetMethods() - where method.Name == message.MethodName - let methodParameters = method.GetParameters() - where message.Parameters.Count == methodParameters.Length - select method).FirstOrDefault(); -#endif + return GetMethodInfo(target.GetType(), message.MethodName, message); }; /// @@ -452,7 +440,7 @@ public override string ToString() { foreach (string possibleGuardName in possibleGuardNames) { matchingGuardName = possibleGuardName; - guard = GetMethodInfo(targetType, "get_" + matchingGuardName); + guard = GetMethodInfo(targetType, "get_" + matchingGuardName, context.Message); if (guard != null) break; } @@ -500,7 +488,7 @@ static MethodInfo TryFindGuardMethod(ActionExecutionContext context, IEnumerable MethodInfo guard = null; foreach (string possibleGuardName in possibleGuardNames) { - guard = GetMethodInfo(targetType, possibleGuardName); + guard = GetMethodInfo(targetType, possibleGuardName, context.Message); if (guard != null) break; } @@ -554,6 +542,55 @@ static MethodInfo GetMethodInfo(Type t, string methodName) return t.GetRuntimeMethods().SingleOrDefault(m => m.Name == methodName); #else return t.GetMethod(methodName); +#endif + } + + static MethodInfo GetMethodInfo(Type t, string methodName, ActionMessage message) + { +#if WINDOWS_UWP + var methods = (from method in t.GetRuntimeMethods() + where method.Name == methodName + let methodParameters = method.GetParameters() + where message.Parameters.Count == methodParameters.Length + && message.Parameters.OfType().Zip(methodParameters, + (parameter, info) => info.ParameterType.IsInstanceOfType(parameter.Value)).All(b => b) + select method); + + MethodInfo returnMethodInfo = null; + foreach (MethodInfo method in methods) + { + returnMethodInfo = method; + if (method.GetParameters().Zip(message.Parameters.OfType(), (info, parameter) => + parameter.Value.GetType().IsAssignableFrom(info.ParameterType) + ).All(b => b)) + { + break; + } + } + + return returnMethodInfo; +#else + var methods = (from method in t.GetMethods() + where method.Name == methodName + let methodParameters = method.GetParameters() + where message.Parameters.Count == methodParameters.Length + && message.Parameters.Zip(methodParameters, + (parameter, info) => info.ParameterType.IsInstanceOfType(parameter.Value)).All(b => b) + select method); + + MethodInfo returnMethodInfo = null; + foreach (MethodInfo method in methods) + { + returnMethodInfo = method; + if (method.GetParameters().Zip(message.Parameters, (info, parameter) => + parameter.Value.GetType().IsAssignableFrom(info.ParameterType) + ).All(b => b)) + { + break; + } + } + + return returnMethodInfo; #endif } } diff --git a/src/Caliburn.Micro.Platform/Platforms/Xamarin.Forms/ActionMessage.cs b/src/Caliburn.Micro.Platform/Platforms/Xamarin.Forms/ActionMessage.cs index 7d108f782..dfb8c8794 100644 --- a/src/Caliburn.Micro.Platform/Platforms/Xamarin.Forms/ActionMessage.cs +++ b/src/Caliburn.Micro.Platform/Platforms/Xamarin.Forms/ActionMessage.cs @@ -313,11 +313,7 @@ public override string ToString() /// The matching method, if available. public static Func GetTargetMethod = (message, target) => { - return (from method in target.GetType().GetRuntimeMethods() - where method.Name == message.MethodName - let methodParameters = method.GetParameters() - where message.Parameters.Count == methodParameters.Length - select method).FirstOrDefault(); + return GetMethodInfo(target.GetType(), message.MethodName, message); }; /// @@ -443,7 +439,7 @@ static MethodInfo TryFindGuardMethod(ActionExecutionContext context, IEnumerable MethodInfo guard = null; foreach (string possibleGuardName in possibleGuardNames) { - guard = GetMethodInfo(targetType, possibleGuardName); + guard = GetMethodInfo(targetType, possibleGuardName, context.Message); if (guard != null) break; } @@ -495,5 +491,30 @@ static MethodInfo TryFindGuardMethod(ActionExecutionContext context, IEnumerable static MethodInfo GetMethodInfo(Type t, string methodName) { return t.GetRuntimeMethods().SingleOrDefault(m => m.Name == methodName); } + + static MethodInfo GetMethodInfo(Type t, string methodName, ActionMessage message) + { + var methods = (from method in t.GetRuntimeMethods() + where method.Name == methodName + let methodParameters = method.GetParameters() + where message.Parameters.Count == methodParameters.Length + && message.Parameters.OfType().Zip(methodParameters, + (parameter, info) => info.ParameterType.IsInstanceOfType(parameter.Value)).All(b => b) + select method); + + MethodInfo returnMethodInfo = null; + foreach (MethodInfo method in methods) + { + returnMethodInfo = method; + if (method.GetParameters().Zip(message.Parameters.OfType(), (info, parameter) => + parameter.Value.GetType().IsAssignableFrom(info.ParameterType) + ).All(b => b)) + { + break; + } + } + + return returnMethodInfo; + } } }