Commit 9011b009 authored by Ramon's avatar Ramon
Browse files

initial code base, still contains a bug where the constraint solver could not...

initial code base, still contains a bug where the constraint solver could not find a solution for some cases
parent 2b43ccf7
Pipeline #2700 skipped
src/invlang/inverter/Z3Wrapper.java
package invlang.inverter;
import static invlang.util.Constants.*;
import static org.chocosolver.solver.constraints.IntConstraintFactory.*;
import static org.chocosolver.solver.constraints.LogicalConstraintFactory.*;
import org.chocosolver.solver.constraints.IntConstraintFactory;
import org.chocosolver.solver.constraints.LogicalConstraintFactory;
import invlang.mapperReader.InvLangHandler;
import invlang.semantics.State;
import invlang.semantics.programTree.expressionTree.BinaryOperator;
import invlang.semantics.programTree.expressionTree.Expression;
import invlang.semantics.programTree.expressionTree.LiteralLeaf;
import invlang.semantics.programTree.expressionTree.UnaryOperator;
import invlang.semantics.programTree.expressionTree.VariableLeaf;
import invlang.types.Enum;
import invlang.types.EnumValue;
import invlang.types.Flag;
import invlang.types.FlagSet;
import invlang.types.Type;
import invlang.util.Constants.Choco;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import org.chocosolver.solver.ResolutionPolicy;
import org.chocosolver.solver.Solver;
import org.chocosolver.solver.constraints.Constraint;
import org.chocosolver.solver.constraints.set.SetConstraintsFactory;
import org.chocosolver.solver.search.strategy.ISF;
import org.chocosolver.solver.trace.Chatterbox;
import org.chocosolver.solver.variables.BoolVar;
import org.chocosolver.solver.variables.IntVar;
import org.chocosolver.solver.variables.SetVar;
import org.chocosolver.solver.variables.Variable;
import org.chocosolver.solver.variables.VariableFactory;
import org.chocosolver.util.ESat;
import com.sun.javafx.binding.IntegerConstant;
public class ChocoWrapper {
private final Map<String, Variable> chocoVars;
private final Solver solver;
private final IntVar minusOne;
private final Map<String, Type> inputTypes;
private final Random random;
/**
* Use an minimalisation-target to get the solutions close to random points.
*/
private final boolean useRandomisation;
/**
* Does not use randomization
* @param state
* @param inputTypes
*/
public ChocoWrapper(State state, Map<String, Type> inputTypes) {
this(state, inputTypes, false, null);
}
/**
* Uses randomization
* @param state
* @param inputTypes
* @param seed
*/
public ChocoWrapper(State state, Map<String, Type> inputTypes, long seed) {
this(state, inputTypes, true, new Random(seed));
}
private ChocoWrapper(State state, Map<String, Type> inputTypes, boolean useRandomization, Random r) {
this.random = r;
this.useRandomisation = useRandomization;
this.chocoVars = new HashMap<>();
this.solver = new Solver();
this.inputTypes = inputTypes;
this.minusOne = VariableFactory.fixed(-1, solver);
// determine the choco-representation of all variables
for (Entry<String, Type> varType : inputTypes.entrySet()) {
Type type = varType.getValue();
String varName = varType.getKey();
if (type == Type.BOOLEAN) {
this.chocoVars.put(varName,VariableFactory.bool(varName, solver));
} else if (type == Type.INT32) {
this.chocoVars.put(varName, VariableFactory.bounded(varName,
VariableFactory.MIN_INT_BOUND,
VariableFactory.MAX_INT_BOUND,
solver));
} else if (type == Type.FLAGSET) {
this.chocoVars.put(varName,
VariableFactory.set(varName, 0, Flag.values().length-1,
solver));
} else {
Enum e = ((invlang.types.Enum.EnumType) type).getEnum();
this.chocoVars.put(varName,
VariableFactory.enumerated(varName, 0, e.getNumberOfElements()-1, solver));
}
}
}
/**
* returns a variable representing the sum of distances (of every dimension) to
* a random point. I.e. if possible, all integer variables get a random value.
* @return a target to minimize (for randomization) or null, of no integers are used.
*/
private IntVar distToRandomPoint() {
if (this.chocoVars.size() == 0) {
return null;
}
IntVar[] diffs = new IntVar[this.chocoVars.size()];
int nrInts = 0;
int max = VariableFactory.MAX_INT_BOUND / this.chocoVars.size();
int min = VariableFactory.MIN_INT_BOUND / this.chocoVars.size();
for (Variable var : this.chocoVars.values()) {
if (var instanceof IntVar && !(var instanceof BoolVar)) {
IntVar intVar = (IntVar) var;
IntVar diff = VariableFactory.bounded("randomizer_" + var.getName(), VariableFactory.MIN_INT_BOUND, VariableFactory.MAX_INT_BOUND, solver);
int randomTarget = this.random.nextInt(max - min) + min;
solver.post(IntConstraintFactory.distance(intVar, newConstant(randomTarget), "=", diff));
diffs[nrInts] = diff;
nrInts++;
}
}
if (nrInts == 0) {
return null;
}
IntVar sum = VariableFactory.bounded("randomizer", VariableFactory.MIN_INT_BOUND, VariableFactory.MAX_INT_BOUND, solver);
IntVar[] diffConstraint = new IntVar[nrInts];
System.arraycopy(diffs, 0, diffConstraint, 0, nrInts);
solver.post(IntConstraintFactory.sum(diffConstraint, sum));
return sum;
}
/**
* Adds this expression to the constraint solver
* @param constraints
*/
public Map<String, Object> resolve(Expression constraints) {
constraints = constraints.prepareForSolver();
solver.post(arithm(toVar(constraints), Choco.EQ, 1));
if (useRandomisation) {
IntVar randomOptimization = distToRandomPoint();
if (randomOptimization != null) {
solver.findOptimalSolution(ResolutionPolicy.MINIMIZE, randomOptimization);
}
}
/*if (random != null) {
Collection<Variable> varCollection = this.chocoVars.values();
IntVar[] vars = new IntVar[varCollection.size()];
int i = 0;
for (Variable var : this.chocoVars.values()) {
vars[i++] = (IntVar) var;
}
solver.set(ISF.custom(ISF.random_var_selector(random.nextLong()), ISF.random_value_selector(random.nextLong()), vars));
}*/
if (!solver.findSolution()) {
return null;
}
Map<String, Object> values = new HashMap<>();
for (Entry<String, Type> input : this.inputTypes.entrySet()) {
Type type = input.getValue();
String varName = input.getKey();
Variable var = this.chocoVars.get(varName);
Object result;
if (type == Type.BOOLEAN) {
ESat val = ((BoolVar) var).getBooleanValue();
result = val == ESat.TRUE ? true :
val == ESat.FALSE ? false :
null;
} else if (type == Type.INT32) {
result = ((IntVar) var).getValue();
} else if (type == Type.FLAGSET) {
result = ChocoWrapper.asFlagSet((SetVar) var);
} else if (type instanceof invlang.types.Enum.EnumType) {
result = ((invlang.types.Enum.EnumType) type).getEnum().getValue(((IntVar) var).getValue());
} else {
throw new RuntimeException();
}
values.put(varName, result);
}
return values;
}
private RuntimeException unpostable(Expression e) {
return new RuntimeException("Cannot post expression as constraint: " +e);
}
private int intCounter = 0;
private IntVar newInt() {
return VariableFactory.bounded("i" + (intCounter++), VariableFactory.MIN_INT_BOUND,
VariableFactory.MAX_INT_BOUND, solver);
}
private IntVar newConstant(int val) {
return VariableFactory.fixed(val, solver);
}
private int boolCounter = 0;
private BoolVar newBool() {
return VariableFactory.bool("b" + (boolCounter++), solver);
}
/*private IntVar toVar(int i) {
IntVar tmp = newInt();
solver.post(arithm(tmp, Choco.EQ, i));
return tmp;
}*/
@SuppressWarnings("unchecked")
private <T extends Variable> T toVar(Expression expr) {
if (expr instanceof LiteralLeaf) {
Object value = ((LiteralLeaf) expr).getValue();
if (value instanceof Integer) {
return (T) VariableFactory.fixed((Integer) value, solver);
} else if (value instanceof Boolean){
return (T) VariableFactory.fixed((Boolean) value, solver);
} else if (value instanceof EnumValue) {
return (T) VariableFactory.fixed(((EnumValue) value).getOrdinal(), solver);
} else if (value instanceof FlagSet) {
return (T) fromFlagSet((FlagSet) value);
} else {
throw new RuntimeException("Unexpected expr '" + value + "' with class '" + value.getClass().getCanonicalName() + "'");
}
} else if (expr instanceof VariableLeaf) {
String varName = ((VariableLeaf) expr).getVarName();
Variable var = this.chocoVars.get(varName);
if (var == null) {
throw new RuntimeException("Cannot find variable '" + varName + "' in map " + this.chocoVars);
}
return (T) var;
} else if (expr instanceof UnaryOperator) {
UnaryOperator unaryExpr = (UnaryOperator) expr;
switch (unaryExpr.operator().toString()) {
case UNARYMINUS:
IntVar tmpInt = newInt();
solver.post(times(toVar(unaryExpr.getChild()), minusOne, tmpInt));
//solver.post(arithm(tmpInt, "=", toVar(unaryExpr.getChild()), "*", -1));
return (T) tmpInt;
case NEG:
BoolVar tmpBool = newBool();
//solver.post(times(toVar(unaryExpr.getChild()), this.minusOne, tmpBool));
solver.post(arithm(toVar(unaryExpr.getChild()), Choco.NEQ, solver.ONE));
return (T) tmpBool;
default:
throw unpostable(expr);
}
} else if (expr instanceof BinaryOperator) {
BinaryOperator binExpr = (BinaryOperator) expr;
switch (binExpr.operator().toString()) {
case PLUS:
IntVar tmpInt = newInt();
// generate a flat array, such that a+b+c+d becomes one constraint instead of 3
solver.post(sum(asFlatIntArray(binExpr), tmpInt));
return (T) tmpInt;
case MINUS:
throw new RuntimeException("Cannot solve expressions with binary minus: "
+ "pre-process to rewrite to unary minus\n" + expr);
/*tmpInt = newInt();
solver.post(arithm(tmpInt, Choco.EQ, toVar(binExpr.getLeftChild()), Choco.MINUS, toVar(binExpr.getRightChild())));
return (T)tmpInt;*/
case MULT:
tmpInt = newInt();
solver.post(times(toVar(binExpr.getLeftChild()), toVar(binExpr.getRightChild()), tmpInt));
return (T)tmpInt;
case DIV:
tmpInt = newInt();
solver.post(eucl_div(toVar(binExpr.getLeftChild()), toVar(binExpr.getRightChild()), tmpInt));
return (T)tmpInt;
case AND:
Constraint andExpr = and(asFlatBoolArray(binExpr));
BoolVar tmpBool = newBool();
// the result of (a & b & c) is x, where (if x then (a & b & c) else !(a & b & c))
ifThenElse(tmpBool, andExpr, not(andExpr));
return (T) tmpBool;
case OR:
Constraint orExpr = or(asFlatBoolArray(binExpr));
tmpBool = newBool();
// the result of (a & b & c) is x, where (if x then (a & b & c) else !(a & b & c))
ifThenElse(tmpBool, orExpr, not(orExpr));
return (T) tmpBool;
case HAS:
tmpBool = newBool();
Constraint subset = SetConstraintsFactory.subsetEq(new SetVar[]{
toVar(binExpr.getRightChild()), toVar(binExpr.getLeftChild())
});
ifThenElse(tmpBool, subset, not(subset));
case EQ:
Variable left = toVar(binExpr.getLeftChild()), right = toVar(binExpr.getRightChild());
if (left instanceof IntVar && right instanceof IntVar) {
return toComparisonVariable((IntVar)left, (IntVar)right, Choco.EQ);
} else if (left instanceof SetVar && right instanceof SetVar) {
Constraint setEq = SetConstraintsFactory.all_equal(new SetVar[]{(SetVar) left, (SetVar) right});
tmpBool = newBool();
ifThenElse(tmpBool, setEq, not(setEq));
return (T) tmpBool;
} else {
throw new RuntimeException();
}
case NEQ:
left = toVar(binExpr.getLeftChild());
right = toVar(binExpr.getRightChild());
if (left instanceof IntVar && right instanceof IntVar) {
return toComparisonVariable((IntVar)left, (IntVar)right, Choco.EQ);
} else if (left instanceof SetVar && right instanceof SetVar) {
Constraint setEq = SetConstraintsFactory.all_equal(new SetVar[]{(SetVar) left, (SetVar) right});
tmpBool = newBool();
ifThenElse(tmpBool, not(setEq), setEq);
return (T) tmpBool;
} else {
throw new RuntimeException();
}
case LESS:
return toComparisonVariable(binExpr, Choco.LESS);
case GREATER:
return toComparisonVariable(binExpr, Choco.GREATER);
case LE:
return toComparisonVariable(binExpr, Choco.LE);
case GE:
return toComparisonVariable(binExpr, Choco.GE);
default: throw unpostable(expr);
}
} else {
throw unpostable(expr);
}
}
@SuppressWarnings("unchecked")
public <T extends Variable> T toComparisonVariable(IntVar left, IntVar right, String chocoComp) {
BoolVar tmpBool = newBool();
Constraint compExpr = arithm(left, chocoComp, right);
ifThenElse(tmpBool, compExpr, not(compExpr));
return (T) tmpBool;
}
public <T extends Variable> T toComparisonVariable(BinaryOperator binExpr, String chocoComp) {
Variable left = toVar(binExpr.getLeftChild()), right = toVar(binExpr.getRightChild());
if (left instanceof IntVar) {
return toComparisonVariable((IntVar)left, (IntVar)right, chocoComp);
} else {
throw new RuntimeException();
}
}
public IntVar[] asFlatIntArray(BinaryOperator binExpr) {
ArrayList<Expression> flatList = binExpr.flatten();
IntVar[] flatArray = new IntVar[flatList.size()];
for (int i = 0; i < flatList.size(); i++) {
flatArray[i] = toVar(flatList.get(i));
}
return flatArray;
}
public BoolVar[] asFlatBoolArray(BinaryOperator binExpr) {
ArrayList<Expression> flatList = binExpr.flatten();
BoolVar[] flatArray = new BoolVar[flatList.size()];
for (int i = 0; i < flatList.size(); i++) {
flatArray[i] = toVar(flatList.get(i));
}
return flatArray;
}
private SetVar fromFlagSet(FlagSet flags) {
int[] ints = new int[flags.size()];
int i = 0;
for (Flag flag : flags) {
ints[i++] = flag.ordinal();
}
return VariableFactory.set("anonymous_set", ints, solver);
}
private static FlagSet asFlagSet(SetVar set) {
int size = set.getEnvelopeSize();
int[] setVals = new int[size];
if (size != 0) {
setVals[0] = set.getEnvelopeFirst();
for (int i = 1; i < size; i++) {
setVals[i] = set.getEnvelopeNext();
}
}
return new FlagSet(setVals);
}
}
package invlang.inverter;
import invlang.semantics.EnumDefinitions;
import invlang.semantics.State;
import invlang.semantics.programTree.IfElse;
import invlang.semantics.programTree.OutputStatement;
import invlang.semantics.programTree.StatementList;
import invlang.semantics.programTree.expressionTree.BinaryOperator;
import invlang.semantics.programTree.expressionTree.Expression;
import invlang.semantics.programTree.expressionTree.LiteralLeaf;
import invlang.semantics.programTree.expressionTree.NextLeaf;
import invlang.semantics.programTree.expressionTree.UnaryOperator;
import invlang.semantics.programTree.expressionTree.VariableLeaf;
import invlang.types.Type;
import invlang.types.TypeEnvironment;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.chocosolver.solver.constraints.Constraint;
import static invlang.util.Constants.*;
public class ConstraintFinder {
private final TypeEnvironment types;
private final boolean useNextOperator;
/**
* Can create expression trees representing the characteristic of the input-
* output-relation of a program.
* @param types The type environment containing all variable types
* @param useNextOperator wrap left-hand sides of assignments in a next-operator
*/
public ConstraintFinder(TypeEnvironment types, boolean useNextOperator) {
this.types = types;
this.useNextOperator = useNextOperator;
}
/**
* Find an expression representing the constraints on the inputs
* @param program the main branch of the program
* @param expectedOutputs the values that the output of the function should have
* @return For every path, a set of conjunct conditions is found. The disjunction of
* all these paths is returned.
*/
public Expression getProgramConstraints(StatementList program) {
return getProgramConstraints(program, null, null);
}
/**
* Find an expression representing the constraints on the inputs
* @param program the main branch of the program
* @param state used for reduction of expressions, may be null if reductions is not required
* @param expectedOutputs the values that the output of the function should have. May be null
* if reduction is not required
* @return For every path, a set of conjunct conditions is found. The disjunction of
* all these paths is returned.
*/
public Expression getProgramConstraints(StatementList program,
State state, Map<String, Object> expectedOutputs) {
Expression main = null;
for (OutputStatement output : program.outputs()) {
if (state != null && output.expression instanceof LiteralLeaf) {
// small optimization: if output is of form <out = const>, filter
// directly -> false if equality doesn't hold, no constraint if it does
Object value = output.expression.getValue(state, this.types);
if (!value.equals(expectedOutputs.get(output.varName))) {
return new LiteralLeaf(false, Type.BOOLEAN);
}
} else {
Expression newConstraint = new BinaryOperator(
this.useNextOperator
? new NextLeaf(output.varName, types.getVariableType(output.varName))
: new VariableLeaf(output.varName, types.getVariableType(output.varName))
, output.expression, EQ);
main = main == null ? newConstraint : new BinaryOperator(main, newConstraint, AND);
}
}
for (IfElse ifElse : program.branches()) {
Expression branch1 = getProgramConstraints(ifElse.branch1, state, expectedOutputs),
branch2 = getProgramConstraints(ifElse.branch2, state, expectedOutputs);
branch1 = new BinaryOperator(branch1, ifElse.condition, AND);
branch2 = new BinaryOperator(branch2, ifElse.condition.negate(), AND);
Expression branchExpr = new BinaryOperator(branch1, branch2, OR);
main = main == null ? branchExpr : new BinaryOperator(main, branchExpr, AND);
}
return main == null ? new LiteralLeaf(true, Type.BOOLEAN) : main;
}
}
This diff is collapsed.
package invlang.inverter;
import invlang.semantics.programTree.expressionTree.Expression;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.chocosolver.solver.Solver;
import org.chocosolver.solver.constraints.Constraint;
public class ProgramPath implements Iterable<Expression>, Cloneable {
private final Map<String, Object> outputValues = new HashMap<>();
private final Set<Expression> inputConditions = new HashSet<>();
public ProgramPath() {
}
public void addConstraint(Expression constraint) {
this.inputConditions.add(constraint);
}
public void addOutputCondition(String varName, Object value) {
this.outputValues.put(varName, value);
}
@Override
public Iterator<Expression> iterator() {
return this.inputConditions.iterator();
}
public ProgramPath clone() {
ProgramPath path = new ProgramPath();
path.append(this);
return path;
}
/**
* Adds all conditions and constraints of the given path
* @param toAppend
*/
public void append(ProgramPath toAppend) {
this.outputValues.putAll(toAppend.outputValues);
this.inputConditions.addAll(toAppend.inputConditions);
}
/**
* Clones this path and appends the given path to it
* @param path
* @return
*/
public ProgramPath cloneAndAppend(ProgramPath toAppend) {
ProgramPath path = this.clone();
path.append(toAppend);
return path;
}
}
package invlang.inverter;
import invlang.mapperReader.InvLangHandler;
public class ReducedIntRange {
public final int reducedStart, start, length;
public ReducedIntRange(int reducedStart, int start, int length) {
this.reducedStart = reducedStart;
this.length = length;
this.start = start;
}
public int reduce(int value) {
if (!isInRange(value)) {
throw new RuntimeException("Reduced int range misused: requested value is not in range [" + start + ", " + (start + length) + ")");
}
int reducedValue = value - start + reducedStart;
return reducedValue;
}
public int expand(int reducedValue) {
if (!isReducedInRange(reducedValue)) {
throw new RuntimeException("Cannot expand " + reducedValue + ": not in range [" + start + "," + (start+length) + ")");
}
int expandedValue = start + (reducedValue - reducedStart);
return expandedValue;
}
public boolean isInRange(int value) {
return value >= start && value < start + length;
}