| 1 | :- module(predicate_data_generator, [generate_synthesis_data_from_predicate/3, | |
| 2 | generate_synthesis_data_from_predicate/5]). | |
| 3 | ||
| 4 | :- use_module(library(random)). | |
| 5 | :- use_module(library(sets), [intersect/2]). | |
| 6 | ||
| 7 | :- use_module(probsrc(bmachine), [b_load_machine_from_file/1, | |
| 8 | b_get_machine_variables/1, | |
| 9 | b_get_all_used_identifiers/1, | |
| 10 | b_get_main_filename/1, | |
| 11 | bmachine_is_precompiled/0, | |
| 12 | b_machine_precompile/0]). | |
| 13 | :- use_module(probsrc(parsercall), [parse_formula/2]). | |
| 14 | :- use_module(probsrc(translate), [translate_bexpression/2]). | |
| 15 | :- use_module(probsrc(solver_interface), [solve_predicate/3, type_check_in_machine_context/2]). | |
| 16 | :- use_module(probsrc(bsyntaxtree), [conjunct_predicates/2, get_texpr_info/2, find_identifier_uses/3]). | |
| 17 | :- use_module(probsrc('synthesis/deep_learning/ground_truth')). | |
| 18 | :- use_module(probsrc('synthesis/deep_learning/b_machine_identifier_normalization')). | |
| 19 | :- use_module(probsrc(preferences), [set_preference/2, | |
| 20 | get_preference/2]). | |
| 21 | ||
| 22 | :- use_module(synthesis(synthesis_util), [get_input_nodes_from_bindings/2, | |
| 23 | create_equality_nodes_from_example/2, | |
| 24 | b_get_typed_invariant_from_machine/1]). | |
| 25 | ||
% Lower bound used when drawing the number of positive and of negative examples
% (see get_random_amount_of_examples/1); also the minimum total number of examples
% a record must contain to be kept (see get_augmented_states_for_predicate_rand/7).
min_amount_of_examples(3).
% Largest value from which the upper bound for the number of positive / negative
% examples is derived (get_random_amount_of_examples/1 passes this value plus one to random/3).
max_amount_of_examples(positive, 8).
max_amount_of_examples(negative, 8).

% Number of augmentation records requested by generate_synthesis_data_from_predicate/3.
augment_records(5).
% Value generate_synthesis_data_from_predicate/3 passes on as solver timeout;
% it is stored in the time_out preference by set_desired_preferences/4.
solver_timeout_ms(10000).
| 32 | ||
%% get_random_amount_of_examples(-Amounts).
%
% Amounts is a pair (Positive,Negative). Both components are drawn with random/3
% using min_amount_of_examples/1 as lower bound and the matching
% max_amount_of_examples/2 value plus one as upper bound
% (presumably because random/3 excludes its upper bound - confirm against library(random)).
get_random_amount_of_examples((Positive,Negative)) :-
    min_amount_of_examples(Lower),
    max_amount_of_examples(positive, PositiveMax),
    max_amount_of_examples(negative, NegativeMax),
    PositiveUpper is PositiveMax+1,
    random(Lower, PositiveUpper, Positive),
    NegativeUpper is NegativeMax+1,
    random(Lower, NegativeUpper, Negative).
| 41 | ||
%% random_list_of_numbers(+Count, +Acc, -List).
%
% List extends Acc by Count pairwise distinct pairs (Positive,Negative) as drawn by
% get_random_amount_of_examples/1; a drawn pair that already occurs in Acc is redrawn.
% Fails (instead of looping forever) if Count is not a non-negative integer or if
% Acc already holds as many pairs as can possibly be drawn, i.e., no fresh pair exists.
% NOTE(review): the exhaustion check assumes Acc only contains pairs produced by
% get_random_amount_of_examples/1 (callers in this file start with Acc = []).
random_list_of_numbers(0, Acc, Acc) :-
    !.
random_list_of_numbers(C, Acc, L) :-
    integer(C),
    C > 0,
    min_amount_of_examples(Min),
    max_amount_of_examples(positive, PMax),
    max_amount_of_examples(negative, NMax),
    % number of distinct pairs that can be drawn at all
    Distinct is max(0, PMax-Min+1) * max(0, NMax-Min+1),
    length(Acc, Len),
    Len < Distinct,
    draw_unseen_amount_of_examples(Acc, R),
    C1 is C-1,
    random_list_of_numbers(C1, [R|Acc], L).

% Draw pairs until one is found that is not a member of Seen.
draw_unseen_amount_of_examples(Seen, R) :-
    repeat,
    get_random_amount_of_examples(Candidate),
    \+ member(Candidate, Seen),
    !,
    R = Candidate.
| 52 | ||
%% exclude_solution(+State, +ExclusionPred, -NewExclusionPred).
%
% Extend ExclusionPred so that State is ruled out: the equality nodes created from
% State are conjoined and handed to truth_or_exclude/3.
% An empty State leaves ExclusionPred unchanged.
exclude_solution(State, ExclusionPred, NewExclusionPred) :-
    State = [],
    NewExclusionPred = ExclusionPred,
    !.
exclude_solution(State, ExclusionPred, NewExclusionPred) :-
    create_equality_nodes_from_example(State, Equalities),
    conjunct_predicates(Equalities, StateAsPredicate),
    truth_or_exclude(StateAsPredicate, ExclusionPred, NewExclusionPred).
| 59 | ||
%% truth_or_exclude(+EqualityConj, +ExclusionPred, -NewExclusionPred).
%
% If EqualityConj is the truth predicate there is nothing to exclude and
% NewExclusionPred is ExclusionPred. Otherwise NewExclusionPred is the conjunction
% of the negated EqualityConj and ExclusionPred.
truth_or_exclude(EqualityConj, ExclusionPred, NewExclusionPred) :-
    (   EqualityConj = b(truth,pred,[]),
        NewExclusionPred = ExclusionPred
    ->  true
    ;   Negated = b(negation(EqualityConj),pred,[]),
        NewExclusionPred = b(conjunct(Negated,ExclusionPred),pred,[])
    ).
| 63 | ||
%% filter_machine_var_states(+VarStates, +MachineVars, +Acc, -States).
%
% Keep those elements of VarStates whose info list contains synthesis(machinevar, Name)
% with Name being the identifier of a member of MachineVars.
% Kept elements are pushed onto Acc, so States lists them in reverse order of VarStates.
filter_machine_var_states([], _, Acc, Acc).
filter_machine_var_states([Candidate|Rest], MachineVars, Acc, States) :-
    (   get_texpr_info(Candidate, Info),
        member(synthesis(machinevar, Name), Info),
        member(b(identifier(Name), _, _), MachineVars)
    ->  NextAcc = [Candidate|Acc]
    ;   NextAcc = Acc
    ),
    filter_machine_var_states(Rest, MachineVars, NextAcc, States).
| 73 | ||
%% map_translate_var_state(+Env, +Asts, -Triples, -NewEnv).
%
% For each AST in Asts: normalize its identifiers (threading the environment from
% Env to NewEnv), then build the triple (VarName,Type,Pretty) where VarName is taken
% from the synthesis(machinevar, VarName) info of the normalized AST, Type is its type
% and Pretty is its translation by translate_bexpression/2.
% Fails if a normalized AST carries no synthesis(machinevar, _) info.
map_translate_var_state(Env, [], [], Env).
map_translate_var_state(Env, [Ast|Asts], [(VarName,Type,Pretty)|Triples], NewEnv) :-
    normalize_ids_in_b_ast(Env, Ast, Normalized, NextEnv),
    Normalized = b(_, Type, Info),
    member(synthesis(machinevar, VarName), Info),
    translate_bexpression(Normalized, Pretty),
    map_translate_var_state(NextEnv, Asts, Triples, NewEnv).
| 81 | ||
%% get_amount_of_states_for_predicate(+Env, +AmountOfExamples, +ExclusionPred, +MachineVars, +VPredicateAst, +Acc, -ListOfExamples, -NewEnv).
%
% Collect up to AmountOfExamples states satisfying VPredicateAst. Each round solves
% the conjunction of ExclusionPred and VPredicateAst, keeps the machine variable part
% of the solution, extends the exclusion predicate by that state and pushes the
% translated state onto Acc (so the most recently found state comes first).
% Stops early, returning what has been collected so far, as soon as a round does not
% yield a solution with a non-empty machine variable state.
get_amount_of_states_for_predicate(Env, 0, _, _, _, Acc, Acc, Env) :-
    !.
get_amount_of_states_for_predicate(Env, Wanted, ExclusionPred, MachineVars, VPredicateAst, Acc, ListOfExamples, NewEnv) :-
    StillWanted is Wanted-1,
    Query = b(conjunct(ExclusionPred,VPredicateAst),pred,[]),
    solve_predicate(Query, _, Result),
    Result = solution(Bindings),
    get_input_nodes_from_bindings(Bindings, AllNodes),
    filter_machine_var_states(AllNodes, MachineVars, [], State),
    State \== [],
    exclude_solution(State, ExclusionPred, NextExclusionPred),
    map_translate_var_state(Env, State, PrettyState, NextEnv),
    !,
    get_amount_of_states_for_predicate(NextEnv, StillWanted, NextExclusionPred, MachineVars, VPredicateAst, [PrettyState|Acc], ListOfExamples, NewEnv).
% no (further) usable solution: return the examples collected so far
get_amount_of_states_for_predicate(Env, _, _, _, _, Acc, Acc, Env).
| 97 | ||
%% get_augmented_states_for_predicate(+Env, +AR, +MachineVars, +UsedComponents, +PredicateAst, -AugmentedSetOfData, -NewEnv).
%
% Draw AR distinct pairs of example amounts via random_list_of_numbers/3 and generate
% one record per pair using get_augmented_states_for_predicate_rand/7.
get_augmented_states_for_predicate(Env, AR, MachineVars, UsedComponents, PredicateAst, AugmentedSetOfData, NewEnv) :-
    random_list_of_numbers(AR, [], ExampleAmounts),
    get_augmented_states_for_predicate_rand(Env, ExampleAmounts, MachineVars, UsedComponents, PredicateAst, AugmentedSetOfData, NewEnv).
| 101 | ||
%% get_augmented_states_for_predicate_rand(+Env, +ExampleAmounts, +MachineVars, +UsedComponents, +PredicateAst, -Records, -NewEnv).
%
% For each pair (Positive,Negative) in ExampleAmounts, collect up to Positive states
% satisfying PredicateAst and up to Negative states satisfying its negation.
% A record (PositiveStates,NegativeStates,UsedComponents) is emitted only if the
% total number of collected states reaches min_amount_of_examples/1; otherwise the
% pair is skipped and the environment of that attempt is discarded.
get_augmented_states_for_predicate_rand(Env, [], _, _, _, [], Env).
get_augmented_states_for_predicate_rand(Env, [(PositiveWanted,NegativeWanted)|Rest], MachineVars, UsedComponents, PredicateAst, [(PositiveStates,NegativeStates,UsedComponents)|Records], NewEnv) :-
    NoExclusion = b(truth,pred,[]),
    get_amount_of_states_for_predicate(Env, PositiveWanted, NoExclusion, MachineVars, PredicateAst, [], PositiveStates, EnvAfterPositive),
    NegatedPredicate = b(negation(PredicateAst),pred,[]),
    get_amount_of_states_for_predicate(EnvAfterPositive, NegativeWanted, NoExclusion, MachineVars, NegatedPredicate, [], NegativeStates, EnvAfterNegative),
    length(PositiveStates, PositiveCount),
    length(NegativeStates, NegativeCount),
    min_amount_of_examples(Minimum),
    PositiveCount+NegativeCount >= Minimum,
    !,
    get_augmented_states_for_predicate_rand(EnvAfterNegative, Rest, MachineVars, UsedComponents, PredicateAst, Records, NewEnv).
% too few examples found for this pair: skip it
get_augmented_states_for_predicate_rand(Env, [_|Rest], MachineVars, UsedComponents, PredicateAst, Records, NewEnv) :-
    get_augmented_states_for_predicate_rand(Env, Rest, MachineVars, UsedComponents, PredicateAst, Records, NewEnv).
| 116 | ||
%% remove_foolish_sub_predicates(+AllUsedIds, +TPredicateAst, -PredicateAst).
%
% Remove predicates that do not refer to any identifier within a machine (AllUsedIds).
% For instance, 'x: NAT & x < 10 & NAT \/ NAT1 = NAT' will be reduced to 'x: NAT & x < 10' with x being a machine variable.
% Conjunctions and disjunctions are reduced recursively (see
% remove_foolish_sub_predicates_binary/6); any other predicate is kept as is if it
% uses at least one identifier from AllUsedIds. Fails if nothing would remain.
% TODO: remove typing predicates ???
remove_foolish_sub_predicates(AllUsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    nonvar(Node),
    % Match on the functor: Node is a compound term such as conjunct(Lhs,Rhs) and is
    % therefore never identical to the plain atom conjunct or disjunct.
    (   Node = conjunct(Lhs,Rhs) -> Functor = conjunct
    ;   Node = disjunct(Lhs,Rhs) -> Functor = disjunct
    ),
    !,
    remove_foolish_sub_predicates_binary(AllUsedIds, Ast, Functor, Lhs, Rhs, NewAst).
% fail if sub-predicate does not use any ids from AllUsedIds
remove_foolish_sub_predicates(AllUsedIds, Ast, Ast) :-
    find_identifier_uses(Ast, [], UsedIds),
    intersect(AllUsedIds, UsedIds),
    !.
| 133 | ||
%% remove_foolish_sub_predicates_binary(+AllUsedIds, +Ast, +Functor, +Lhs, +Rhs, -NewAst).
%
% Reduce both operands of the binary predicate Ast (operator Functor, operands Lhs and Rhs):
%  - both operands can be reduced: rebuild the node from the reduced operands, keeping
%    type and info of Ast,
%  - only one operand can be reduced: NewAst is that reduced operand,
%  - neither operand can be reduced: fail.
% Each operand is reduced exactly once, so the traversal does not repeat work on
% nested conjunctions/disjunctions.
remove_foolish_sub_predicates_binary(AllUsedIds, b(_,Type,Info), Functor, Lhs, Rhs, NewAst) :-
    (   remove_foolish_sub_predicates(AllUsedIds, Lhs, NLhs)
    ->  (   remove_foolish_sub_predicates(AllUsedIds, Rhs, NRhs)
        ->  NNode =.. [Functor, NLhs, NRhs],
            NewAst = b(NNode,Type,Info)
        ;   % right side does not use any ids from AllUsedIds
            NewAst = NLhs
        )
    ;   % left side does not use any ids from AllUsedIds;
        % fail if the right side does not use any either
        remove_foolish_sub_predicates(AllUsedIds, Rhs, NewAst)
    ).
| 145 | ||
%% generate_synthesis_data_from_predicate(+MachinePath, +RawPredicate, -AugmentedSetOfData).
%
% Convenience variant of generate_synthesis_data_from_predicate/5 that takes the
% number of augmentation records from augment_records/1 and the solver timeout from
% solver_timeout_ms/1.
% AugmentedSetOfData is a list of triples (PositiveExamples, NegativeExamples, GroundTruth);
% each example is a list of triples (MachineVar, Type, Value).
% The machine at MachinePath is loaded by generate_synthesis_data_from_predicate/5
% if it is not the currently loaded one.
generate_synthesis_data_from_predicate(MachinePath, RawPredicate, AugmentedSetOfData) :-
    solver_timeout_ms(TimeoutMs),
    augment_records(Records),
    generate_synthesis_data_from_predicate(MachinePath, Records, TimeoutMs, RawPredicate, AugmentedSetOfData).
| 158 | ||
%% load_b_machine_if_unloaded(+MachinePath).
%
% Nothing is done if a machine is precompiled and the main filename ends with
% MachinePath. Otherwise the machine is loaded from MachinePath and precompiled.
load_b_machine_if_unloaded(MachinePath) :-
    (   bmachine_is_precompiled,
        b_get_main_filename(CurrentPath),
        % suffix check: MachinePath may be a trailing part of the loaded path
        atom_concat(_, MachinePath, CurrentPath)
    ->  true
    ;   b_load_machine_from_file(MachinePath),
        b_machine_precompile
    ).
| 167 | ||
%% generate_synthesis_data_from_predicate(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +RawPredicate, -AugmentedSetOfData).
%
% Load the machine at MachinePath if necessary, parse and type check RawPredicate in
% its context, and generate AugmentRecords records of positive and negative examples
% for the predicate conjoined with the machine's invariant.
% AugmentRecords is the amount of data augmentations, SolverTimeoutMs the value set
% for the time_out preference while generating.
% Fails if the current machine uses records or if any processing step fails.
% The preferences changed by set_desired_preferences/4 are restored no matter whether
% the generation succeeds, fails or raises an exception.
generate_synthesis_data_from_predicate(MachinePath, AugmentRecords, SolverTimeoutMs, RawPredicate, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    \+ current_machine_uses_records,
    set_desired_preferences(SolverTimeoutMs, OldKodkodPref, OldTimeOutPref, OldRandPref),
    % once/1 makes the goal determinate so that the cleanup runs as soon as it exits
    call_cleanup(
        once(generate_synthesis_data_with_preferences_set(AugmentRecords, RawPredicate, AugmentedSetOfData)),
        reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref)).

% Actual data generation; expects the machine to be loaded and the preferences to be set.
generate_synthesis_data_with_preferences_set(AugmentRecords, RawPredicate, AugmentedSetOfData) :-
    atom_codes(RawPredicate, RawPredicateCodes),
    parse_formula(RawPredicateCodes, ParsedPredicate),
    type_check_in_machine_context([ParsedPredicate], TParsedPredicates),
    TParsedPredicates = [TPredicateAst],
    b_get_machine_variables(MachineVars),
    b_get_all_used_identifiers(AllUsedIds),
    remove_foolish_sub_predicates(AllUsedIds, TPredicateAst, PredicateAst),
    % the ground truth is extracted from the predicate itself, without the invariant
    get_library_components_from_pred_or_expr(PredicateAst, UsedComponents),
    get_normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames),
    Env = [[], NormalizedSets, NormalizedIds, NOperationNames],
    b_get_typed_invariant_from_machine(Invariant),
    PredWithInv = b(conjunct(Invariant, PredicateAst),pred,[]),
    get_augmented_states_for_predicate(Env, AugmentRecords, MachineVars, UsedComponents, PredWithInv, AugmentedSetOfData, _).
| 189 | ||
%% reset_old_preferences(+OldKodkodPref, +OldTimeOutPref, +OldRandPref).
%
% Write the given values back to the preferences try_kodkod_on_load, time_out and
% randomise_enumeration_order (in this order).
reset_old_preferences(KodkodOnLoad, TimeOut, RandomiseEnumeration) :-
    set_preference(try_kodkod_on_load, KodkodOnLoad),
    set_preference(time_out, TimeOut),
    set_preference(randomise_enumeration_order, RandomiseEnumeration).
| 196 | ||
%% set_desired_preferences(+Timeout, -OldKodkodPref, -OldTimeOutPref, -OldRandPref).
%
% Switch try_kodkod_on_load off, switch randomise_enumeration_order on and set
% time_out to Timeout. The value each preference had before is returned so that
% reset_old_preferences/3 can restore it.
set_desired_preferences(Timeout, PreviousKodkod, PreviousTimeOut, PreviousRand) :-
    % each preference is read right before it is overwritten
    get_preference(try_kodkod_on_load, PreviousKodkod),
    set_preference(try_kodkod_on_load, false),
    get_preference(randomise_enumeration_order, PreviousRand),
    set_preference(randomise_enumeration_order, true),
    get_preference(time_out, PreviousTimeOut),
    set_preference(time_out, Timeout).