1
/* Copyright (C) 2004-2006 Versant Inc. http://www.db4o.com */
3
using System.Collections.Generic;
4
using Db4objects.Db4o.Activation;
5
using Db4objects.Db4o.Instrumentation.Cecil;
6
using Db4objects.Db4o.TA;
7
using Mono.Collections.Generic;
9
namespace Db4objects.Db4o.NativeQueries
12
using System.Collections;
13
using System.Reflection;
17
using Cecil.FlowAnalysis;
18
using Cecil.FlowAnalysis.ActionFlow;
19
using Cecil.FlowAnalysis.CodeStructure;
20
using Ast = Cecil.FlowAnalysis.CodeStructure;
24
using Expr.Cmp.Operand;
25
using NQExpression = Expr.IExpression;
29
// NOTE(review): this copy of the file has lost its braces, blank lines and some statements;
// the bare numbers are stray original line numbers. Comments below describe only the
// statements that are visible.
/// <summary>
/// Builds a Db4objects.Db4o.NativeQueries.Expr tree out of a predicate method definition.
/// </summary>
31
public class QueryExpressionBuilder
33
// Cecil assembly definitions keyed by assembly file location; a miss reads the file with
// AssemblyDefinition.ReadAssembly. Presumably only the most recently used entry is kept,
// as the SingleItemCachingStrategy name suggests — confirm against that class.
protected static ICachingStrategy<string, AssemblyDefinition> _assemblyCachingStrategy =
34
new SingleItemCachingStrategy<string, AssemblyDefinition>( delegate(string location)
36
return AssemblyDefinition.ReadAssembly(location);
39
// Translated expressions keyed by the reflected predicate method. A miss resolves the Cecil
// method definition, translates it (FromMethodDefinition) and then adjusts constants that are
// compared against value-type fields (AdjustBoxedValueTypes).
protected static ICachingStrategy<MethodBase, IExpression> _expressionCachingStrategy =
40
new SingleItemCachingStrategy<MethodBase, IExpression>(
41
delegate(MethodBase method)
43
MethodDefinition methodDef = GetMethodDefinition(method);
44
return AdjustBoxedValueTypes(FromMethodDefinition(methodDef));
48
/// <summary>
/// Returns the native-query expression tree for a predicate method, served through the
/// static expression cache.
/// </summary>
/// <param name="method">The reflected predicate method.</param>
/// <exception cref="ArgumentNullException"><paramref name="method"/> is null.</exception>
public NQExpression FromMethod(MethodBase method)
{
    if (null == method)
    {
        throw new ArgumentNullException("method");
    }
    NQExpression expression = GetCachedExpression(method);
    return expression;
}

// Single entry point to the expression cache; on a miss the cache's factory delegate
// resolves, translates and post-processes the predicate.
private static NQExpression GetCachedExpression(MethodBase method)
{
    return _expressionCachingStrategy.Get(method);
}
60
// Resolves the Cecil MethodDefinition for a reflected predicate method.
// Two lookups are visible: a signature-based search (MethodDefinitionFor) and a lookup of
// method.MetadataToken in the main module of the assembly loaded from the method's location.
// NOTE(review): methodDef is declared twice below, which cannot compile as shown; the original
// presumably selected one lookup with conditional-compilation directives that are missing
// from this copy (as is the final return) — confirm against the real source.
// Raises UnsupportedPredicateException (via UnsupportedPredicate) when no definition is found.
private static MethodDefinition GetMethodDefinition(MethodBase method)
62
string location = GetAssemblyLocation(method);
64
MethodDefinition methodDef = MethodDefinitionFor(method);
66
AssemblyDefinition assembly = _assemblyCachingStrategy.Get(location);
68
MethodDefinition methodDef = (MethodDefinition)assembly.MainModule.LookupToken(method.MetadataToken);
70
if (null == methodDef) UnsupportedPredicate(string.Format("Unable to load the definition of '{0}' from assembly '{1}'", method, location));
75
// Finds the Cecil definition of a reflected method by matching name, parameter count and
// parameter types (by full name) against the methods of its declaring type's definition.
// NOTE(review): the handling of a null declaringType and of a matching candidate is not
// visible in this copy; the visible fallback looks the method up by metadata token.
// NOTE(review): GetParameterTypes is recomputed for every candidate that passes the name and
// count checks; it could be computed once before the loop.
private static MethodDefinition MethodDefinitionFor(MethodBase method)
77
string location = GetAssemblyLocation(method);
78
AssemblyDefinition assembly = _assemblyCachingStrategy.Get(location);
81
TypeDefinition declaringType = FindTypeDefinition(assembly.MainModule, method.DeclaringType);
82
if (declaringType == null)
87
foreach (MethodDefinition candidate in declaringType.Methods)
89
if (candidate.Name != method.Name) continue;
90
if (candidate.Parameters.Count != method.GetParameters().Length) continue;
91
if (!ParametersMatch(candidate.Parameters, GetParameterTypes(method, assembly.MainModule))) continue;
100
return (MethodDefinition) assembly.MainModule.LookupToken(method.MetadataToken);
104
// Runs BoxedValueTypeProcessor over the tree, which rewrites (in place) constants compared
// against value-type fields so that they carry the field's enum type.
// The return statement is not visible in this copy; presumably the same expression is returned.
private static NQExpression AdjustBoxedValueTypes(NQExpression expression)
106
expression.Accept(new BoxedValueTypeProcessor());
110
// Maps each parameter of method (taken from the open generic definition when applicable,
// see ParametersFor) to the TypeDefinition found for its type in module.
// NOTE(review): FindTypeDefinition may yield null (callers elsewhere null-check its result);
// such nulls are added to the list as-is. The final return is not visible in this copy.
private static IList<TypeReference> GetParameterTypes(MethodBase method, ModuleDefinition module)
112
IList<TypeReference> types = new List<TypeReference>();
113
foreach (ParameterInfo parameter in ParametersFor(method))
115
types.Add(FindTypeDefinition(module, parameter.ParameterType));
121
// Returns the parameters to use when matching a reflected method against Cecil metadata.
// Cecil describes generic methods and methods of generic types in their open form, so for
// those the parameters are taken from the open generic definition (e.g. T rather than int).
//
// The open-type counterpart is located by metadata token instead of by name:
// Type.GetMethod(string) throws AmbiguousMatchException when the name is overloaded and only
// searches public members, so overloaded or non-public methods on generic types failed.
private static ParameterInfo[] ParametersFor(MethodBase method)
{
    if (method.IsGenericMethod)
    {
        MethodInfo methodInfo = (MethodInfo) method;
        return methodInfo.GetGenericMethodDefinition().GetParameters();
    }

    if (!method.DeclaringType.IsGenericType)
    {
        return method.GetParameters();
    }

    MethodBase openMethod = FindMethodOnGenericTypeDefinition(method);
    // Fall back to the closed signature when no open counterpart exists (e.g. constructors,
    // which GetMethods does not return).
    return openMethod != null ? openMethod.GetParameters() : method.GetParameters();
}

// Returns the method declared on method's open generic declaring type that shares method's
// metadata definition (same module and metadata token), or null when there is none.
private static MethodBase FindMethodOnGenericTypeDefinition(MethodBase method)
{
    const BindingFlags declaredMethods = BindingFlags.Public | BindingFlags.NonPublic
        | BindingFlags.Instance | BindingFlags.Static | BindingFlags.DeclaredOnly;

    Type openType = method.DeclaringType.GetGenericTypeDefinition();
    foreach (MethodInfo candidate in openType.GetMethods(declaredMethods))
    {
        if (candidate.MetadataToken == method.MetadataToken && candidate.Module == method.Module)
        {
            return candidate;
        }
    }
    return null;
}
134
// Locates the Cecil TypeDefinition in module that corresponds to a reflected type.
// Nested types are resolved through their declaring type. Generic types are looked up by the
// full name of their generic type definition (e.g. "Ns.Foo`1"): the reflection FullName of a
// constructed type embeds its type arguments, and the bare Name ("Foo`1") drops the
// namespace that ModuleDefinition.GetType needs, so namespaced generic types were not found.
private static TypeDefinition FindTypeDefinition(ModuleDefinition module, Type type)
{
    if (IsNested(type))
    {
        return FindNestedTypeDefinition(module, type);
    }

    string cecilName = type.IsGenericType
        ? type.GetGenericTypeDefinition().FullName
        : type.FullName;
    return FindTypeDefinition(module, cecilName);
}
141
// True when type is declared inside another type, whatever its accessibility.
// Every nested-visibility flag is checked: protected, protected internal and private
// protected nested types used to be treated as top-level types, so their lookup by
// reflection FullName ("Outer+Inner") found nothing in Cecil.
private static bool IsNested(Type type)
{
    return type.IsNestedPublic
        || type.IsNestedPrivate
        || type.IsNestedAssembly
        || type.IsNestedFamily
        || type.IsNestedFamORAssem
        || type.IsNestedFamANDAssem;
}
146
// Resolves a nested reflected type: finds the definition of its declaring type and scans that
// definition's NestedTypes for an entry with the same simple Name.
// The not-found result is not visible in this copy.
// NOTE(review): if the declaring type itself is not found, .NestedTypes is read from null.
private static TypeDefinition FindNestedTypeDefinition(ModuleDefinition module, Type type)
148
foreach (TypeDefinition td in FindTypeDefinition(module, type.DeclaringType).NestedTypes)
150
if (td.Name == type.Name) return td;
155
// Looks a type up in module by its Cecil full name (nested types separated by '/').
// Callers treat a null result as "not found".
private static TypeDefinition FindTypeDefinition(ModuleDefinition module, string name)
{
    TypeDefinition definition = module.GetType(name);
    return definition;
}
160
// Path of the module file that declares method's declaring type. It keys the assembly cache
// and is the path handed to AssemblyDefinition.ReadAssembly.
private static string GetAssemblyLocation(MethodBase method)
{
    System.Reflection.Module declaringModule = method.DeclaringType.Module;
    return declaringModule.FullyQualifiedName;
}
165
/// <summary>
/// Translates a Cecil predicate method definition into a native-query expression tree.
/// </summary>
/// <param name="method">The predicate method definition to translate.</param>
/// <exception cref="UnsupportedPredicateException">
/// The method fails validation, yields no expression, or contains untranslatable code.
/// </exception>
public static NQExpression FromMethodDefinition(MethodDefinition method)
{
    ValidatePredicateMethodDefinition(method);

    Expression body = GetQueryExpression(method);
    if (body == null)
    {
        UnsupportedPredicate("No expression found.");
    }

    AssemblyResolver resolver = new AssemblyResolver(_assemblyCachingStrategy);
    Visitor translator = new Visitor(method, resolver);
    body.Accept(translator);
    return translator.Expression;
}
177
// Rejects predicate methods the translator cannot handle: they must take exactly one
// parameter, contain no exception handlers and return System.Boolean. Each violation raises
// UnsupportedPredicateException via UnsupportedPredicate.
// NOTE(review): the null check guarding the ArgumentNullException is missing from this copy.
private static void ValidatePredicateMethodDefinition(MethodDefinition method)
180
throw new ArgumentNullException("method");
181
if (1 != method.Parameters.Count)
182
UnsupportedPredicate("A predicate must take a single argument.");
183
if (0 != method.Body.ExceptionHandlers.Count)
184
UnsupportedPredicate("A predicate can not contain exception handlers.");
185
if (method.ReturnType.FullName != typeof(bool).FullName)
186
UnsupportedPredicate("A predicate must have a boolean return type.");
189
// Builds the method's control-flow graph, lifts it to an action-flow graph and extracts the
// expression the predicate returns.
private static Expression GetQueryExpression(MethodDefinition method)
{
    ActionFlowGraph graph = FlowGraphFactory.CreateActionFlowGraph(
        FlowGraphFactory.CreateControlFlowGraph(method));
    return GetQueryExpression(graph);
}

// Aborts translation by throwing UnsupportedPredicateException with the given reason.
private static void UnsupportedPredicate(string reason)
{
    UnsupportedPredicateException error = new UnsupportedPredicateException(reason);
    throw error;
}

// Aborts translation, naming the offending expression in the message.
private static void UnsupportedExpression(Expression node)
{
    string reason = "Unsupported expression: " + ExpressionPrinter.ToString(node);
    UnsupportedPredicate(reason);
}
205
// Extracts the boolean expression computed by straight-line predicate code by walking the
// action-flow graph from its first block:
//  - invocations are accepted only when they are Activate calls, either direct or through a
//    method that does nothing else; other invocations reach UnsupportedExpression;
//  - conditional branches are rejected, unconditional branches are followed;
//  - only assignments to local variables are accepted, each variable at most once, and the
//    assigned expression is recorded by variable index;
//  - at the return block, a returned local variable is replaced by its recorded expression.
// NOTE(review): breaks, braces and part of the return expression are missing from this copy.
private static Expression GetQueryExpression(ActionFlowGraph afg)
207
IDictionary<int, Expression> variables = new Dictionary<int, Expression>();
208
ActionBlock block = afg.Blocks[0];
209
while (block != null)
211
switch (block.ActionType)
213
case ActionType.Invoke:
214
InvokeActionBlock invokeBlock = (InvokeActionBlock)block;
215
MethodInvocationExpression invocation = invokeBlock.Expression;
216
if (IsActivateInvocation(invocation)
217
|| IsNoSideEffectIndirectActivationInvocation(invocation))
219
block = invokeBlock.Next;
223
UnsupportedExpression(invocation);
226
case ActionType.ConditionalBranch:
227
UnsupportedPredicate("Conditional blocks are not supported.");
230
case ActionType.Branch:
231
block = ((BranchActionBlock)block).Target;
234
case ActionType.Assign:
236
AssignActionBlock assignBlock = (AssignActionBlock)block;
237
AssignExpression assign = assignBlock.AssignExpression;
238
VariableReferenceExpression variable = assign.Target as VariableReferenceExpression;
239
if (null == variable)
241
UnsupportedExpression(assign);
245
if (variables.ContainsKey(variable.Variable.Index))
246
UnsupportedExpression(assign.Expression);
248
variables.Add(variable.Variable.Index, assign.Expression);
249
block = assignBlock.Next;
254
case ActionType.Return:
256
Expression expression = ((ReturnActionBlock)block).Expression;
257
VariableReferenceExpression variable = expression as VariableReferenceExpression;
258
// NOTE(review): the indexer throws KeyNotFoundException if the returned variable was never assigned.
return null == variable
260
: variables[variable.Variable.Index];
267
// True when invocation targets a method whose action-flow graph has exactly two blocks, the
// first being an Activate invocation (presumably "call Activate, then return"), i.e. an
// indirect activation with no other side effect. False when the target method definition
// cannot be resolved.
// NOTE(review): the result for every other graph shape is not visible in this copy.
private static bool IsNoSideEffectIndirectActivationInvocation(MethodInvocationExpression invocation)
269
MethodDefinition methodDefinition = MethodDefinitionFor(invocation);
270
if (null == methodDefinition) return false;
271
ActionFlowGraph afg = FlowGraphFactory.CreateActionFlowGraph(FlowGraphFactory.CreateControlFlowGraph(methodDefinition));
273
if (afg.Blocks.Count == 2 && afg.Blocks[0].ActionType == ActionType.Invoke)
275
InvokeActionBlock invocationBlock = (InvokeActionBlock) afg.Blocks[0];
276
return IsActivateInvocation(invocationBlock.Expression);
282
// Cecil definition of the method an invocation calls, or null when the call target is not a
// plain method reference.
private static MethodDefinition MethodDefinitionFor(MethodInvocationExpression invocation)
{
    MethodReferenceExpression target = invocation.Target as MethodReferenceExpression;
    return target == null ? null : GetMethodDefinition(target);
}

// True when the invocation directly calls an Activate method (see IsActivateMethod).
private static bool IsActivateInvocation(MethodInvocationExpression invocation)
{
    MethodReferenceExpression target = invocation.Target as MethodReferenceExpression;
    return target != null && IsActivateMethod(target.Method);
}

// True for a method named "Activate" that is either declared on IActivatable itself or is an
// Activate(ActivationPurpose) declared on a type implementing IActivatable.
private static bool IsActivateMethod(MethodReference method)
{
    if (method.Name == "Activate")
    {
        if (method.DeclaringType.FullName == typeof(IActivatable).FullName) return true;
        return IsOverridenActivateMethod(method);
    }
    return false;
}
304
// Checks whether method is an Activate(ActivationPurpose) declared on a type that lists
// IActivatable among its interfaces. The declaring type is looked up by full name in the
// module of the method's declaring-type reference.
// NOTE(review): if that lookup yields null, DeclaringTypeImplementsIActivatable reads
// .Interfaces from null. The final return is not visible in this copy.
private static bool IsOverridenActivateMethod(MethodReference method)
306
TypeDefinition declaringType = FindTypeDefinition(method.DeclaringType.Module, method.DeclaringType.FullName);
307
if (!DeclaringTypeImplementsIActivatable(declaringType)) return false;
308
if (method.Parameters.Count != 1 ||
309
method.Parameters[0].ParameterType.FullName != typeof(ActivationPurpose).FullName) return false;
314
// Scans the interfaces recorded on type for IActivatable, compared by full name.
// NOTE(review): interfaces implemented only by a base type are not inspected — confirm intended.
// The return statements are not visible in this copy.
private static bool DeclaringTypeImplementsIActivatable(TypeDefinition type)
316
foreach (TypeReference itf in type.Interfaces)
318
if (itf.FullName == typeof (IActivatable).FullName)
327
// Returns the referenced method's definition: directly when the reference already is a
// definition (same module), otherwise by loading it from the declaring type's assembly.
private static MethodDefinition GetMethodDefinition(MethodReferenceExpression methodRef)
{
    MethodDefinition definition = methodRef.Method as MethodDefinition;
    if (definition != null) return definition;
    return LoadExternalMethodDefinition(methodRef);
}

// Resolves the assembly that declares the referenced method, finds the declaring type there
// by full name and picks the method with matching name and parameter types.
// NOTE(review): GetType expects a type-definition name; a declaring type that is a generic
// instance presumably is not found — confirm.
private static MethodDefinition LoadExternalMethodDefinition(MethodReferenceExpression methodRef)
{
    MethodReference reference = methodRef.Method;
    TypeReference declaringTypeRef = reference.DeclaringType;
    AssemblyResolver resolver = new AssemblyResolver(_assemblyCachingStrategy);
    AssemblyDefinition owner = resolver.ForTypeReference(declaringTypeRef);
    TypeDefinition declaringType = owner.MainModule.GetType(declaringTypeRef.FullName);
    return GetMethod(declaringType, reference);
}
341
// Finds the method of type whose name and parameter types (compared by full name) match
// template. The results for a match and for no match are not visible in this copy.
private static MethodDefinition GetMethod(TypeDefinition type, MethodReference template)
343
foreach (MethodDefinition method in type.Methods)
345
if (method.Name != template.Name) continue;
346
if (method.Parameters.Count != template.Parameters.Count) continue;
347
if (!ParametersMatch(method.Parameters, template.Parameters)) continue;
356
// Compares parameter types with template types by FullName, position by position.
// NOTE(review): a null template (GetParameterTypes can add nulls) would throw here.
private static bool ParametersMatch(Collection<ParameterDefinition> parameters, IList<TypeReference> templates)
358
return ParametersMatch(parameters, templates, delegate(ParameterDefinition candidate, TypeReference template)
360
return candidate.ParameterType.FullName == template.FullName;
365
// Compares two parameter lists by the FullName of each parameter's type, position by position.
private static bool ParametersMatch(IList<ParameterDefinition> parameters, IList<ParameterDefinition> templates)
367
return ParametersMatch(parameters, templates, delegate(ParameterDefinition candidate, ParameterDefinition template)
369
return candidate.ParameterType.FullName == template.ParameterType.FullName;
373
// False when the lists differ in length or predicate rejects any pair at the same index.
// The final "all matched" return is not visible in this copy.
private static bool ParametersMatch<T>(IList<ParameterDefinition> parameters, IList<T> templates, ParameterMatch<T> predicate)
375
if (parameters.Count != templates.Count) return false;
377
for (int i = 0; i < parameters.Count; i++)
379
ParameterDefinition parameter = parameters[i];
380
if (!predicate(parameter, templates[i])) return false;
386
// Decides whether a candidate parameter matches the template at the same position.
private delegate bool ParameterMatch<T>(ParameterDefinition candidate, T template);
388
// Translates a Cecil.FlowAnalysis code-structure expression into a native-query expression.
// Each visit leaves its translation in the single slot _current, from which callers pop it.
class Visitor : AbstractCodeStructureVisitor
390
// Single-slot expression "stack": Push asserts it is empty, Pop asserts it is not.
private object _current;
391
// Inlining depth counter; InsideCandidate is true while it is positive.
private int _insideCandidate;
392
// Methods currently being translated (the top-level predicate plus any inlined methods);
// consulted by AssertMethodCanBeVisited to reject recursive inlining.
readonly IList _methodDefinitionStack = new ArrayList();
393
// Maps Cecil field references to the field references stored in FieldValue nodes.
private readonly CecilReferenceProvider _referenceProvider;
395
// Starts translation of topLevelMethod: records it as the outermost method being visited and
// builds the field-reference provider for the main module of its declaring assembly.
public Visitor(MethodDefinition topLevelMethod, AssemblyResolver resolver)
{
    EnterMethodDefinition(topLevelMethod);
    ModuleDefinition mainModule = resolver.ForType(topLevelMethod.DeclaringType).MainModule;
    _referenceProvider = CecilReferenceProvider.ForModule(mainModule);
}

// Pushes method onto the stack of methods being translated.
private void EnterMethodDefinition(MethodDefinition method)
{
    _methodDefinitionStack.Add(method);
}

// Pops the innermost method; it must be the method being left (checked in debug builds only).
private void LeaveMethodDefinition(MethodDefinition method)
{
    int topIndex = _methodDefinitionStack.Count - 1;
    object top = _methodDefinitionStack[topIndex];
    System.Diagnostics.Debug.Assert(method == top);
    _methodDefinitionStack.RemoveAt(topIndex);
}
415
/// <summary>
/// The translation left in the slot after visiting. A constant result is turned into the
/// matching boolean constant expression; anything else is returned as an NQExpression.
/// </summary>
// NOTE(review): the null check that selects between the two returns is missing from this copy.
public NQExpression Expression
419
ConstValue value = _current as ConstValue;
422
return ToNQExpression(value);
424
return (NQExpression)_current;
428
// Turns a constant predicate result into BoolConstExpression.True or BoolConstExpression.False.
private static NQExpression ToNQExpression(ConstValue value)
{
    return IsTrue(value.Value()) ? BoolConstExpression.True : BoolConstExpression.False;
}
434
// Interprets a boxed constant as a boolean through IConvertible (e.g. a non-zero int is true).
private static bool IsTrue(object o)
{
    IConvertible convertible = (IConvertible) o;
    return convertible.ToBoolean(null);
}
439
// True while the body of an inlined candidate method is being translated.
private bool InsideCandidate
{
    get { return 0 < _insideCandidate; }
}

// Casts are transparent: only the cast operand is translated.
public override void Visit(CastExpression node)
{
    Expression operand = node.Target;
    operand.Accept(this);
}

// Assignments, local-variable reads and bare argument references have no query equivalent
// and are rejected.
public override void Visit(AssignExpression node)
{
    UnsupportedExpression(node);
}

public override void Visit(VariableReferenceExpression node)
{
    UnsupportedExpression(node);
}

public override void Visit(ArgumentReferenceExpression node)
{
    UnsupportedExpression(node);
}
464
// Logical not: only UnaryOperator.Not is handled (its body is missing from this copy,
// presumably visiting the operand and calling Negate); other operators are rejected.
public override void Visit(UnaryExpression node)
466
switch (node.Operator)
468
case UnaryOperator.Not:
474
UnsupportedExpression(node);
479
// Binary operators map to comparisons or logical connectives. Inequality is built as an
// equality, >= as Smaller and <= as Greater, i.e. as the complementary comparison; the
// negation that completes each of these (presumably Negate()) is on lines missing from
// this copy. Operands of || and && are translated via Convert. Other operators are rejected.
public override void Visit(Ast.BinaryExpression node)
481
switch (node.Operator)
483
case BinaryOperator.ValueEquality:
484
PushComparison(node.Left, node.Right, ComparisonOperator.ValueEquality);
487
case BinaryOperator.ValueInequality:
488
PushComparison(node.Left, node.Right, ComparisonOperator.ValueEquality);
492
case BinaryOperator.LessThan:
493
PushComparison(node.Left, node.Right, ComparisonOperator.Smaller);
496
case BinaryOperator.GreaterThan:
497
PushComparison(node.Left, node.Right, ComparisonOperator.Greater);
500
case BinaryOperator.GreaterThanOrEqual:
501
PushComparison(node.Left, node.Right, ComparisonOperator.Smaller);
505
case BinaryOperator.LessThanOrEqual:
506
PushComparison(node.Left, node.Right, ComparisonOperator.Greater);
510
case BinaryOperator.LogicalOr:
511
Push(new OrExpression(Convert(node.Left), Convert(node.Right)));
514
case BinaryOperator.LogicalAnd:
515
Push(new AndExpression(Convert(node.Left), Convert(node.Right)));
519
UnsupportedExpression(node);
524
// Replaces the expression in the slot with its negation. The handling of an expression that
// already is a NotExpression (topNot) is missing from this copy — presumably it is unwrapped.
private void Negate()
526
NQExpression top = (NQExpression)Pop();
527
NotExpression topNot = top as NotExpression;
533
Push(new NotExpression(top));
536
// Translates lhs and rhs and pushes a ComparisonExpression whose left operand must be a
// FieldValue and right operand an IComparisonOperand (otherwise the predicate is rejected).
// When the right-hand translation is a candidate field value the operands are swapped.
// NOTE(review): the visits, the pop of `left`, the swap itself and the adjustment applied to
// swapped non-symmetric operators are on lines missing from this copy.
private void PushComparison(Expression lhs, Expression rhs, ComparisonOperator op)
541
object right = Pop();
543
bool areOperandsSwapped = IsCandidateFieldValue(right);
544
if (areOperandsSwapped)
551
AssertType(left, typeof(FieldValue), lhs);
552
AssertType(right, typeof(IComparisonOperand), rhs);
554
Push(new ComparisonExpression((FieldValue)left, (IComparisonOperand)right, op));
556
if (areOperandsSwapped && !op.IsSymmetric())
562
// True when o is a FieldValue rooted at the query candidate (not at the predicate).
private static bool IsCandidateFieldValue(object o)
{
    FieldValue fieldValue = o as FieldValue;
    return fieldValue != null && fieldValue.Root() is CandidateFieldRoot;
}
569
// Dispatches a method call: static two-parameter "op_*" methods become comparisons,
// System.String instance methods become string comparisons, and any other call is handled
// as a (possibly inlined) candidate method. Calls without a method-reference target are rejected.
// NOTE(review): the early returns after each branch are missing from this copy.
public override void Visit(MethodInvocationExpression node)
571
MethodReferenceExpression methodRef = node.Target as MethodReferenceExpression;
572
if (null == methodRef)
573
UnsupportedExpression(node);
575
MethodReference method = methodRef.Method;
576
if (IsOperator(method))
578
ProcessOperatorMethodInvocation(node, method);
582
if (IsSystemString(method.DeclaringType))
584
ProcessStringMethod(node, methodRef);
588
ProcessRegularMethodInvocation(node, methodRef);
591
// True when the Cecil type reference denotes System.String.
private static bool IsSystemString(TypeReference type)
{
    return "System.String" == type.FullName;
}
596
// Translates System.String instance calls taking one string argument into comparisons of the
// call target against the argument: Contains, StartsWith, EndsWith and ValueEquality.
// Calls with another signature, and unrecognised methods, are rejected.
// NOTE(review): the switch on the method name that selects each comparison is missing from
// this copy (presumably String.Contains/StartsWith/EndsWith/Equals).
private void ProcessStringMethod(MethodInvocationExpression node, MethodReferenceExpression methodRef)
598
MethodReference method = methodRef.Method;
600
if (method.Parameters.Count != 1
601
|| !IsSystemString(method.Parameters[0].ParameterType))
603
UnsupportedExpression(methodRef);
609
PushComparison(methodRef.Target, node.Arguments[0], ComparisonOperator.Contains);
613
PushComparison(methodRef.Target, node.Arguments[0], ComparisonOperator.StartsWith);
617
PushComparison(methodRef.Target, node.Arguments[0], ComparisonOperator.EndsWith);
621
PushComparison(methodRef.Target, node.Arguments[0], ComparisonOperator.ValueEquality);
625
UnsupportedExpression(methodRef);
630
// Handles calls to parameterless methods (calls with arguments are rejected) by inlining the
// called method: on `this` only when already inside an inlined candidate method, on the
// predicate argument directly, and on other targets after pushing the target's field value.
// NOTE(review): breaks and the default label are missing from this copy.
private void ProcessRegularMethodInvocation(MethodInvocationExpression node, MethodReferenceExpression methodRef)
632
if (node.Arguments.Count != 0)
633
UnsupportedExpression(node);
635
Expression target = methodRef.Target;
636
switch (target.CodeElementType)
638
case CodeElementType.ThisReferenceExpression:
639
if (!InsideCandidate)
640
UnsupportedExpression(node);
641
ProcessCandidateMethodInvocation(node, methodRef);
644
case CodeElementType.ArgumentReferenceExpression:
645
ProcessCandidateMethodInvocation(node, methodRef);
649
Push(ToFieldValue(target));
650
ProcessCandidateMethodInvocation(node, methodRef);
655
// Translates user-defined binary operator calls (op_*) into comparisons of their two
// arguments. op_Inequality, op_GreaterThanOrEqual and op_LessThanOrEqual are built as the
// complementary comparison; the negation that completes them (presumably Negate()) is on
// lines missing from this copy, as are the "op_Equality" and "op_LessThan" case labels and
// the breaks. Other operators are rejected.
private void ProcessOperatorMethodInvocation(MethodInvocationExpression node, MemberReference methodReference)
657
switch (methodReference.Name)
660
PushComparison(node.Arguments[0], node.Arguments[1], ComparisonOperator.ValueEquality);
663
case "op_Inequality":
664
PushComparison(node.Arguments[0], node.Arguments[1], ComparisonOperator.ValueEquality);
668
// XXX: check if the operations below are really supported for the
669
// data types in question
670
case "op_GreaterThanOrEqual":
671
PushComparison(node.Arguments[0], node.Arguments[1], ComparisonOperator.Smaller);
675
case "op_LessThanOrEqual":
676
PushComparison(node.Arguments[0], node.Arguments[1], ComparisonOperator.Greater);
681
PushComparison(node.Arguments[0], node.Arguments[1], ComparisonOperator.Smaller);
684
case "op_GreaterThan":
685
PushComparison(node.Arguments[0], node.Arguments[1], ComparisonOperator.Greater);
689
UnsupportedExpression(node);
694
// Inlines a called method: resolves its definition, rejects recursion, extracts its single
// returned expression and translates it between EnterCandidateMethod/LeaveCandidateMethod.
// NOTE(review): the null check before the first UnsupportedExpression and the visit of
// `expression` (presumably inside try/finally) are missing from this copy.
private void ProcessCandidateMethodInvocation(Expression methodInvocationExpression, MethodReferenceExpression methodRef)
696
MethodDefinition method = GetMethodDefinition(methodRef);
698
UnsupportedExpression(methodInvocationExpression);
700
AssertMethodCanBeVisited(methodInvocationExpression, method);
702
Expression expression = GetQueryExpression(method);
703
if (null == expression)
704
UnsupportedExpression(methodInvocationExpression);
706
EnterCandidateMethod(method);
713
LeaveCandidateMethod(method);
717
// Rejects (via UnsupportedExpression) a method that is already being translated further up
// the inlining chain, which would otherwise recurse forever.
private void AssertMethodCanBeVisited(Expression methodInvocationExpression, MethodDefinition method)
{
    bool alreadyOnStack = _methodDefinitionStack.Contains(method);
    if (alreadyOnStack)
    {
        UnsupportedExpression(methodInvocationExpression);
    }
}
723
// Records method as being inlined. The update of _insideCandidate is missing from this copy
// (presumably an increment, undone in LeaveCandidateMethod).
private void EnterCandidateMethod(MethodDefinition method)
725
EnterMethodDefinition(method);
729
// Undoes EnterCandidateMethod for method.
private void LeaveCandidateMethod(MethodDefinition method)
732
LeaveMethodDefinition(method);
735
// True for static methods named "op_*" with exactly two parameters, i.e. (presumably)
// user-defined binary operator implementations.
private static bool IsOperator(MethodReference method)
{
    if (method.HasThis) return false;
    if (!method.Name.StartsWith("op_")) return false;
    return method.Parameters.Count == 2;
}

// A field access is translated according to the expression it is accessed on.
public override void Visit(FieldReferenceExpression node)
{
    Expression owner = node.Target;
    PushFieldValueForTarget(node, owner);
}
745
// Pushes a FieldValue for node.Field whose root depends on what the field is accessed on:
//  - the predicate argument: rooted at the candidate;
//  - `this`: chained onto a pending field value in the slot, otherwise rooted at the
//    candidate or at the predicate itself (the condition choosing between these two is
//    missing from this copy — presumably InsideCandidate);
//  - a method call or another field access: chained onto the translated target;
//  - a cast: the cast operand is used instead.
// Any other target is rejected.
// NOTE(review): the cast case recurses with ((CastExpression)node.Target).Target rather than
// ((CastExpression)target).Target, so a cast of a cast would recurse forever — confirm.
private void PushFieldValueForTarget(FieldReferenceExpression node, Expression target)
747
switch (target.CodeElementType)
749
case CodeElementType.ArgumentReferenceExpression:
750
PushFieldValue(CandidateFieldRoot.Instance, node.Field);
753
case CodeElementType.ThisReferenceExpression:
756
if (_current != null)
758
FieldValue current = PopFieldValue(node);
759
PushFieldValue(current, node.Field);
763
PushFieldValue(CandidateFieldRoot.Instance, node.Field);
768
PushFieldValue(PredicateFieldRoot.Instance, node.Field);
772
case CodeElementType.MethodInvocationExpression:
773
case CodeElementType.FieldReferenceExpression:
774
FieldValue value = ToFieldValue(target);
775
PushFieldValue(value, node.Field);
778
case CodeElementType.CastExpression:
779
PushFieldValueForTarget(node, ((CastExpression)node.Target).Target);
783
UnsupportedExpression(node);
788
// Pushes a FieldValue for field, anchored at value (a root or a parent field value).
private void PushFieldValue(IComparisonOperandAnchor value, FieldReference field)
{
    FieldValue fieldValue = new FieldValue(value, _referenceProvider.ForCecilField(field));
    Push(fieldValue);
}

// Literals become constant operands.
public override void Visit(LiteralExpression node)
{
    ConstValue constant = new ConstValue(node.Value);
    Push(constant);
}

// Translates node into a boolean expression, turning a bare field value into a null comparison.
private NQExpression Convert(Expression node)
{
    return ReconstructNullComparisonIfNecessary(node);
}
803
// Returns the translation of node as a boolean expression. A bare field value is turned into
// a comparison of that field against a null constant; any other result must already be an
// NQExpression. NOTE(review): the visit/pop that produces `top`, the expression wrapping the
// comparison (presumably a NotExpression, i.e. "field != null") and its field operand are
// missing from this copy.
private NQExpression ReconstructNullComparisonIfNecessary(Expression node)
808
FieldValue fieldValue = top as FieldValue;
809
if (fieldValue == null)
811
AssertType(top, typeof(NQExpression), node);
812
return (NQExpression)top;
817
new ComparisonExpression(
819
new ConstValue(null),
820
ComparisonOperator.ValueEquality));
823
// Translates node and pops the result as a FieldValue. The visit of node is missing from
// this copy (presumably node.Accept(this)).
FieldValue ToFieldValue(Expression node)
826
return PopFieldValue(node);
829
// Pops the slot, rejecting the predicate unless it holds a FieldValue.
private FieldValue PopFieldValue(Expression node)
831
return (FieldValue)Pop(node, typeof(FieldValue));
834
// Stores value in the slot, which must be empty (debug-only check). The assignment is
// missing from this copy.
void Push(object value)
836
Assert(_current == null, "expression stack must be empty before Push");
840
// Pops the slot and rejects the predicate when the value is not of expectedType.
// The return is missing from this copy.
object Pop(Expression node, Type expectedType)
842
object value = Pop();
843
AssertType(value, expectedType, node);
847
// Rejects the predicate (UnsupportedPredicateException) when the value taken from the slot is
// not assignable to expectedType. An empty slot (null value) is reported the same way instead
// of failing with a NullReferenceException on value.GetType(); the emptiness check in Pop is
// debug-only, so null can reach here in release builds.
private static void AssertType(object value, Type expectedType, Expression sourceExpression)
{
    Type actualType = value == null ? null : value.GetType();
    if (actualType != null && expectedType.IsAssignableFrom(actualType)) return;

    UnsupportedPredicate(
        string.Format("Unsupported expression: {0}. Unexpected type on stack. Expected: {1}, Got: {2}.",
            ExpressionPrinter.ToString(sourceExpression), expectedType, (object) actualType ?? "null"));
}
860
// Body of the parameterless Pop(): its signature, the clearing of _current and the return
// are missing from this copy. The emptiness check below is debug-only.
Assert(_current != null, "expression stack is empty");
861
object value = _current;
866
// Debug-only assertion: System.Diagnostics.Debug.Assert is compiled only when DEBUG is
// defined, so in release builds these conditions are evaluated but not enforced.
private static void Assert(bool condition, string message)
868
System.Diagnostics.Debug.Assert(condition, message);
873
// Post-processing pass over a translated expression tree: constants compared against
// value-type fields are converted to the field's enum type when needed (see AdjustConstValue).
internal class BoxedValueTypeProcessor : TraversingExpressionVisitor
875
// Only comparisons whose left field is a value type and whose right operand is a constant
// are adjusted; everything else is left untouched.
override public void Visit(ComparisonExpression expression)
877
TypeReference fieldType = GetFieldType(expression.Left());
878
if (!fieldType.IsValueType) return;
880
ConstValue constValue = expression.Right() as ConstValue;
881
if (constValue == null) return;
883
AdjustConstValue(fieldType, constValue);
886
// Cecil type of the field a FieldValue refers to; field.Field is expected to be a
// CecilFieldRef (anything else fails the cast).
private static TypeReference GetFieldType(FieldValue field)
{
    CecilFieldRef cecilField = (CecilFieldRef) field.Field;
    return cecilField.FieldType;
}
891
// Converts a boxed value-type constant to the field's enum type (e.g. an int literal compared
// against an enum field) so that both sides of the comparison have the same runtime type.
// Non-value-type constants, non-enum fields and already-matching constants are left alone.
// A null constant (such as the ConstValue(null) of a reconstructed null comparison) has no
// type to convert and used to fail with a NullReferenceException on value.GetType().
private static void AdjustConstValue(TypeReference typeRef, ConstValue constValue)
{
    object value = constValue.Value();
    if (value == null || !value.GetType().IsValueType) return;

    System.Type type = ResolveTypeReference(typeRef);
    if (!type.IsEnum || value.GetType() == type) return;

    constValue.Value(Enum.ToObject(type, value));
}
902
// Loads the runtime Type for a Cecil type reference, throwing if it cannot be found.
private static Type ResolveTypeReference(TypeReference typeRef)
{
    Assembly owner = LoadAssembly(typeRef.Scope);
    // Cecil separates nested type names with '/', reflection with '+'.
    string reflectionName = typeRef.FullName.Replace('/', '+');
    const bool throwOnError = true;
    return owner.GetType(reflectionName, throwOnError);
}
908
// Loads the runtime assembly a Cecil metadata scope refers to. An assembly name reference is
// loaded by its full name; a module definition is mapped to its assembly's name. Any other
// scope (e.g. a module reference) is now rejected explicitly instead of failing with a
// NullReferenceException on moduleDef.Assembly.
private static Assembly LoadAssembly(IMetadataScope scope)
{
    AssemblyNameReference nameRef = scope as AssemblyNameReference;
    if (null != nameRef) return Assembly.Load(nameRef.FullName);

    ModuleDefinition moduleDef = scope as ModuleDefinition;
    if (null == moduleDef || null == moduleDef.Assembly)
    {
        throw new NotSupportedException("Unable to resolve an assembly for metadata scope: " + scope);
    }
    return LoadAssembly(moduleDef.Assembly.Name);
}
b'\\ No newline at end of file'