| 1 | :- module(predicate_data_generator, [generate_synthesis_data_from_predicate_raw/3, | |
| 2 | generate_synthesis_data_from_predicate_raw/5, | |
| 3 | generate_synthesis_data_from_predicate_untyped/5]). | |
| 4 | ||
| 5 | :- use_module(library(lists)). | |
| 6 | :- use_module(library(random)). | |
| 7 | :- use_module(library(sets), [intersect/2]). | |
| 8 | ||
| 9 | :- use_module(probsrc(bmachine), [b_load_machine_from_file/1, | |
| 10 | b_get_machine_variables/1, | |
| 11 | b_get_all_used_identifiers/1, | |
| 12 | b_get_main_filename/1, | |
| 13 | bmachine_is_precompiled/0, | |
| 14 | b_machine_precompile/0]). | |
| 15 | :- use_module(probsrc(parsercall), [parse_formula/2]). | |
| 16 | :- use_module(probsrc(translate), [translate_bexpression/2]). | |
| 17 | :- use_module(probsrc(solver_interface), [solve_predicate/5, | |
| 18 | type_check_in_machine_context/2]). | |
| 19 | :- use_module(probsrc(bsyntaxtree), [conjunct_predicates/2, | |
| 20 | get_texpr_info/2, | |
| 21 | find_identifier_uses/3, | |
| 22 | find_typed_identifier_uses/3]). | |
| 23 | :- use_module(probsrc(preferences), [set_preference/2, | |
| 24 | get_preference/2]). | |
| 25 | ||
| 26 | :- use_module(probsrc('por/b_simplifier')). | |
| 27 | :- use_module(synthesis('deep_learning/ground_truth')). | |
| 28 | :- use_module(synthesis('deep_learning/b_machine_identifier_normalization')). | |
| 29 | :- use_module(synthesis(synthesis_util), [get_input_nodes_from_bindings/2, | |
| 30 | create_equality_nodes_from_example/2, | |
| 31 | b_get_typed_invariant_from_machine/1]). | |
| 32 | ||
% Lower bound for the number of positive and for the number of negative
% examples drawn per augmentation record. It is also the minimum total number
% of examples a record must contain to be kept
% (see get_augmented_states_for_predicate_rand/8).
min_amount_of_examples(3).
% Inclusive upper bounds for the number of positive / negative examples drawn
% per record. get_random_amount_of_examples/1 adds 1 to these because random/3
% excludes its upper bound.
max_amount_of_examples(positive, 8).
max_amount_of_examples(negative, 8).

% Default number of augmented records requested by
% generate_synthesis_data_from_predicate_raw/3.
augment_records(5).
% Default solver timeout in milliseconds, passed on to the time_out preference
% by set_desired_preferences/4.
solver_timeout_ms(10000).
| 39 | ||
%% get_random_amount_of_examples(-(PositiveAmount, NegativeAmount)).
%
% Draw a random pair of example counts. Each count lies between
% min_amount_of_examples/1 and the matching max_amount_of_examples/2 bound,
% both inclusive. The positive count is drawn first, then the negative one.
get_random_amount_of_examples((PositiveAmount, NegativeAmount)) :-
    random_amount_of_examples(positive, PositiveAmount),
    random_amount_of_examples(negative, NegativeAmount).

%% random_amount_of_examples(+Kind, -Amount).
%
% Draw one random count for Kind (positive/negative) in [Min, Max].
random_amount_of_examples(Kind, Amount) :-
    min_amount_of_examples(Lower),
    max_amount_of_examples(Kind, UpperInclusive),
    % random/3 never returns its upper bound, so widen it by one.
    UpperExclusive is UpperInclusive + 1,
    random(Lower, UpperExclusive, Amount).
| 48 | ||
%% random_list_of_numbers(+C, +Acc, -L).
%
% Prepend C distinct random (PositiveAmount, NegativeAmount) pairs, drawn with
% get_random_amount_of_examples/1, to Acc, giving L. Acc is assumed to hold
% distinct pairs already (callers pass []).
%
% Only finitely many distinct pairs exist:
% (MaxPos-Min+1) * (MaxNeg-Min+1), i.e. 36 with the default bounds. The
% request is therefore clamped to the number of pairs still available.
% Without the clamp, a C larger than that (or a negative C) would make the
% rejection-sampling loop below retry forever. C =< 0 yields Acc unchanged.
random_list_of_numbers(C, Acc, L) :-
    max_distinct_amount_pairs(MaxDistinct),
    length(Acc, Taken),
    Count is max(0, min(C, MaxDistinct - Taken)),
    random_distinct_amount_pairs(Count, Acc, L).

%% random_distinct_amount_pairs(+Count, +Acc, -L).
%
% Rejection sampling: redraw whenever the pair is already in Acc.
% Terminates because Count never exceeds the number of unused pairs.
random_distinct_amount_pairs(0, Acc, Acc) :-
    !.
random_distinct_amount_pairs(C, Acc, L) :-
    get_random_amount_of_examples(R),
    (   memberchk(R, Acc)
    ->  random_distinct_amount_pairs(C, Acc, L)
    ;   C1 is C-1,
        random_distinct_amount_pairs(C1, [R|Acc], L)
    ).

%% max_distinct_amount_pairs(-MaxDistinct).
%
% Number of distinct pairs get_random_amount_of_examples/1 can produce.
max_distinct_amount_pairs(MaxDistinct) :-
    min_amount_of_examples(Min),
    max_amount_of_examples(positive, PMax),
    max_amount_of_examples(negative, NMax),
    MaxDistinct is max(0, PMax - Min + 1) * max(0, NMax - Min + 1).
| 59 | ||
%% exclude_solution(+State, +ExclusionPred, -NewExclusionPred).
%
% Extend ExclusionPred so that the variable assignment in State is ruled out
% for future solver calls. An empty State leaves ExclusionPred unchanged.
exclude_solution(State, ExclusionPred, NewExclusionPred) :-
    (   State = [],
        NewExclusionPred = ExclusionPred
    ->  true
    ;   create_equality_nodes_from_example(State, Equalities),
        conjunct_predicates(Equalities, AllEqualities),
        truth_or_exclude(AllEqualities, ExclusionPred, NewExclusionPred)
    ).
| 66 | ||
%% truth_or_exclude(+EqualityConj, +ExclusionPred, -NewExclusionPred).
%
% Conjoin not(EqualityConj) to ExclusionPred. A trivially true EqualityConj
% (the truth node) adds nothing, so ExclusionPred is returned as is.
truth_or_exclude(EqualityConj, ExclusionPred, NewExclusionPred) :-
    (   EqualityConj = b(truth,pred,[]),
        NewExclusionPred = ExclusionPred
    ->  true
    ;   Negated = b(negation(EqualityConj),pred,[]),
        NewExclusionPred = b(conjunct(Negated,ExclusionPred),pred,[])
    ).
| 70 | ||
%% filter_machine_var_states(+VarStates, +MachineVars, +Acc, -States).
%
% Keep the nodes of VarStates whose info list carries a
% synthesis(machinevar, _) marker and prepend them to Acc. The result is in
% reverse input order. MachineVars is currently unused: operation parameters
% are also marked as machinevar but are not in MachineVars, so no
% membership check against MachineVars is done.
filter_machine_var_states([], _MachineVars, States, States).
filter_machine_var_states([VarState|Rest], MachineVars, Acc, States) :-
    (   get_texpr_info(VarState, Info),
        memberchk(synthesis(machinevar, _), Info)
    ->  NextAcc = [VarState|Acc]
    ;   NextAcc = Acc
    ),
    filter_machine_var_states(Rest, MachineVars, NextAcc, States).
| 81 | ||
%% map_translate_var_state(+Env, +Asts, -Triples, -NewEnv).
%
% Translate each variable-state node into a (VarName, Type, PrettyValue)
% triple. Identifiers are normalized on the way, and the normalization
% environment is threaded from one element to the next.
map_translate_var_state(Env, [], [], Env).
map_translate_var_state(EnvIn, [Ast|Asts], [Triple|Triples], EnvOut) :-
    translate_var_state(EnvIn, Ast, Triple, EnvMid),
    map_translate_var_state(EnvMid, Asts, Triples, EnvOut).

%% translate_var_state(+EnvIn, +Ast, -Triple, -EnvOut).
%
% Normalize Ast, read its machinevar name from the info list and
% pretty-print the normalized node.
translate_var_state(EnvIn, Ast, (VarName,Type,PrettyAst), EnvOut) :-
    normalize_ids_in_b_ast(EnvIn, Ast, Normalized, EnvOut),
    Normalized = b(_, Type, Info),
    member(synthesis(machinevar, VarName), Info),
    translate_bexpression(Normalized, PrettyAst).
| 89 | ||
%% get_amount_of_states_for_predicate(+Env, +AmountOfExamples, +ExclusionPred, +MachineVars, +PredicateAst, +Acc, -ListOfExamples, -NewEnv).
%
% Ask the solver up to AmountOfExamples times for a solution of PredicateAst.
% Each solution found is excluded from the next call via ExclusionPred.
% Every accepted solution is restricted to machine-variable states,
% translated into (VarName, Type, PrettyValue) triples and prepended to Acc.
% Stops early, keeping what was collected so far, as soon as no usable
% solution is produced.
get_amount_of_states_for_predicate(Env, 0, _, _, _, Acc, Acc, Env) :-
    !.
get_amount_of_states_for_predicate(Env, Budget, Exclusion, MachineVars, Pred, Found, Examples, FinalEnv) :-
    RemainingBudget is Budget-1,
    (   solve_predicate(b(conjunct(Exclusion,Pred),pred,[]), _, 1, [force_evaluation], Result),
        Result = solution(Bindings),
        get_input_nodes_from_bindings(Bindings, InputNodes),
        filter_machine_var_states(InputNodes, MachineVars, [], State),
        State \== [],
        exclude_solution(State, Exclusion, NextExclusion),
        map_translate_var_state(Env, State, PrettyState, NextEnv)
    ->  get_amount_of_states_for_predicate(NextEnv, RemainingBudget, NextExclusion, MachineVars, Pred, [PrettyState|Found], Examples, FinalEnv)
    ;   % no (usable) solution: stop here
        Examples = Found,
        FinalEnv = Env
    ).
| 105 | ||
%% get_augmented_states_for_predicate(+Env, +AR, +MachineVars, +UsedComponents, +PredicateAst, +TypingPredicate, -AugmentedSetOfData, -NewEnv).
%
% Draw AR distinct random (PositiveAmount, NegativeAmount) pairs and build
% one (PositiveStates, NegativeStates, UsedComponents) record per pair.
% Pairs that yield too few examples are skipped.
get_augmented_states_for_predicate(Env, AR, MachineVars, UsedComponents, PredicateAst, TypingPredicate, AugmentedSetOfData, NewEnv) :-
    random_list_of_numbers(AR, [], AmountPairs),
    get_augmented_states_for_predicate_rand(Env, AmountPairs, MachineVars, UsedComponents,
                                            PredicateAst, TypingPredicate,
                                            AugmentedSetOfData, NewEnv).
| 109 | ||
%% get_augmented_states_for_predicate_rand(+Env, +AmountPairs, +MachineVars, +UsedComponents, +PredicateAst, +TypingPredicate, -Records, -NewEnv).
%
% For each (PositiveAmount, NegativeAmount) pair:
% - collect positive examples of TypingPredicate & PredicateAst;
% - collect negative examples of TypingPredicate & not(PredicateAst);
% - emit a (PositiveStates, NegativeStates, UsedComponents) record if at
%   least min_amount_of_examples/1 examples were found in total.
% A rejected pair is skipped, and the environment from before that attempt
% is kept.
get_augmented_states_for_predicate_rand(Env, [], _, _, _, _, [], Env).
get_augmented_states_for_predicate_rand(Env0, [Amounts|MoreAmounts], MachineVars, UsedComponents, PredicateAst, TypingPredicate, Records, Env) :-
    PositivePred = b(conjunct(TypingPredicate,PredicateAst),pred,[]),
    NegativePred = b(conjunct(TypingPredicate,b(negation(PredicateAst),pred,[])),pred,[]),
    Truth = b(truth,pred,[]),
    (   Amounts = (PosAmount,NegAmount),
        get_amount_of_states_for_predicate(Env0, PosAmount, Truth, MachineVars, PositivePred, [], PositiveStates, Env1),
        get_amount_of_states_for_predicate(Env1, NegAmount, Truth, MachineVars, NegativePred, [], NegativeStates, Env2),
        length(PositiveStates, PosFound),
        length(NegativeStates, NegFound),
        min_amount_of_examples(MinAmount),
        PosFound + NegFound >= MinAmount
    ->  Records = [(PositiveStates,NegativeStates,UsedComponents)|MoreRecords],
        get_augmented_states_for_predicate_rand(Env2, MoreAmounts, MachineVars, UsedComponents, PredicateAst, TypingPredicate, MoreRecords, Env)
    ;   get_augmented_states_for_predicate_rand(Env0, MoreAmounts, MachineVars, UsedComponents, PredicateAst, TypingPredicate, Records, Env)
    ).
| 126 | ||
%% blacklist_component(?Component).
%
% Succeeds for the AST node functors whose occurrence makes filter_predicate/3
% and ast_uses_id_from_list/2 reject a constraint. Enumeration order follows
% the table below.
blacklist_component(Component) :-
    blacklisted_components(Components),
    member(Component, Components).

%% blacklisted_components(-Components).
blacklisted_components([
    % built-in integer sets
    int_set, integer_set, nat_set, nat1_set, natural_set, natural1_set,
    % quantifiers and set comprehension
    forall, exists, comprehension_set,
    % relation and function set constructors
    total_function, total_surjection, total_injection,
    total_relation, total_surjection_relation, surjection_relation,
    partial_function, partial_injection, partial_surjection, partial_bijection,
    % lambda and generalized union/intersection
    lambda, quantified_union, quantified_intersection
    % not blacklisted (kept for reference): function, general_sum, general_product
]).
| 152 | ||
%% filter_predicate(+UsedIds, +Ast, -NewAst).
%
% Prune the predicate Ast down to the constraints relevant for the
% identifiers in UsedIds:
%   - a constraint is kept only if it refers to at least one identifier in
%     UsedIds (checked by ast_uses_id_from_list/2);
%   - a constraint is removed if it contains a node whose functor is listed
%     in blacklist_component/1.
% For instance, with x in UsedIds and c not in UsedIds, 'x < 10 & c = 5' is
% reduced to 'x < 10'.
% NOTE(review): typing constraints such as 'x : NAT' are probably removed as
% well, since NAT is presumably represented by the blacklisted nat_set node.
% Confirm this is intended.
% Fails if nothing of Ast is kept.
%
% NOTE(review): for disjunct/equivalence/implication, a removed operand is
% simply omitted and the surviving operand is returned on its own (see
% filter_predicate_binary/6). That does not preserve the meaning of the
% connective.
filter_predicate(UsedIds, Ast, NewAst) :-
    % Clause 1: logical connectives are filtered operand by operand.
    Ast = b(Node,_,_),
    Node =.. [Functor, Lhs, Rhs],
    ( Functor == conjunct
    ; Functor == disjunct
    ; Functor == equivalence
    ; Functor == implication),
    filter_predicate_binary(UsedIds, Ast, Functor, Lhs, Rhs, Clean),
    % Commit only after filtering succeeded. If no operand survives,
    % filter_predicate_binary/6 fails and the following clauses are tried.
    !,
    NewAst = Clean.
% Clause 2: any other node with exactly two arguments. This includes the
% connectives above when clause 1 failed.
% Remove the constraint if an operand contains a component in
% blacklist_component/1. Otherwise keep the whole node iff it uses an id
% from UsedIds.
filter_predicate(UsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    Node =.. [_, Lhs, Rhs],
    get_used_components(Lhs, [], LhsUsedComponents),
    get_used_components(Rhs, [], RhsUsedComponents),
    findall(BlacklistCmpt, (member(BlacklistCmpt, LhsUsedComponents), blacklist_component(BlacklistCmpt)), LhsBlacklistCmpts),
    findall(BlacklistCmpt, (member(BlacklistCmpt, RhsUsedComponents), blacklist_component(BlacklistCmpt)), RhsBlacklistCmpts),
    % Reached only if get_used_components/3 succeeded for both operands.
    % Otherwise the later clauses are tried.
    !,
    ( LhsBlacklistCmpts == [],
      RhsBlacklistCmpts == []
    -> ast_uses_id_from_list(UsedIds, Ast),
       NewAst = Ast
    ; fail
    ).
% Clause 3: negation. Filter the negated predicate and wrap the result in a
% fresh negation node; the original node's info list is not carried over.
filter_predicate(UsedIds, Negation, NegNewAst) :-
    Negation = b(negation(Ast),pred,_),
    !,
    filter_predicate(UsedIds, Ast, NewAst),
    NegNewAst = b(negation(NewAst),pred,[]).
% Clause 4 (fallback): keep Ast unchanged iff it uses an id from UsedIds and
% contains no component from blacklist_component/1. Fail otherwise.
filter_predicate(UsedIds, Ast, NewAst) :-
    ast_uses_id_from_list(UsedIds, Ast),
    get_used_components(Ast, [], UsedComponents),
    findall(BlacklistCmpt, (member(BlacklistCmpt, UsedComponents), blacklist_component(BlacklistCmpt)), BlacklistCmpts),
    BlacklistCmpts == [],
    NewAst = Ast.
| 195 | ||
%% filter_predicate_binary(+UsedIds, +Ast, +Functor, +Lhs, +Rhs, -NewNode).
%
% Filter both operands of the binary connective Ast, whose node functor is
% Functor. The left operand is processed first, then the right.
% - If both operands survive, the connective is rebuilt with Ast's type and
%   info.
% - If exactly one survives, that operand is returned on its own.
% - If neither survives (the predicate would be truth), the call fails.
filter_predicate_binary(UsedIds, b(_,Type,Info), Functor, Lhs, Rhs, NewNode) :-
    filtered_operand(UsedIds, Lhs, LhsResult),
    filtered_operand(UsedIds, Rhs, RhsResult),
    combine_filtered_operands(LhsResult, RhsResult, Functor, Type, Info, NewNode).

%% filtered_operand(+UsedIds, +Operand, -Result).
%
% Result is kept(Filtered) if filter_predicate/3 keeps Operand, else dropped.
filtered_operand(UsedIds, Operand, Result) :-
    (   filter_predicate(UsedIds, Operand, Filtered)
    ->  Result = kept(Filtered)
    ;   Result = dropped
    ).

%% combine_filtered_operands(+LhsResult, +RhsResult, +Functor, +Type, +Info, -NewNode).
combine_filtered_operands(kept(NLhs), kept(NRhs), Functor, Type, Info, b(NNode,Type,Info)) :-
    !,
    NNode =.. [Functor, NLhs, NRhs].
combine_filtered_operands(kept(NLhs), dropped, _, _, _, NLhs) :-
    !.
combine_filtered_operands(dropped, kept(NRhs), _, _, _, NRhs).
| 210 | ||
%% ast_uses_id_from_list(+UsedIds, +Ast).
%
% Succeeds if the top-level node of Ast is not a blacklisted component and
% at least one identifier used in Ast also occurs in UsedIds.
ast_uses_id_from_list(UsedIds, Ast) :-
    Ast = b(Node, _, _),
    functor(Node, TopFunctor, _),
    \+ blacklist_component(TopFunctor),
    once(bsyntaxtree:find_identifier_uses(Ast, [], IdsInAst)),
    intersect(UsedIds, IdsInAst).
| 218 | ||
%% get_used_components(+Term, +Acc, -UsedComponents).
%
% Collect the node functors ("components") of every b/3 AST node reachable
% from Term and prepend them to Acc, most recently visited node first.
%
% Term may also contain lists (e.g. the element list of set_extension/1 or the
% identifier list of exists/2) and other non-AST compound terms such as value
% wrappers. These are traversed transparently and contribute no component
% themselves. Atomic leaves and unbound variables contribute nothing.
% Handling them here is essential: failing on such terms would make
% filter_predicate/3 silently drop the whole constraint, e.g. 'x : {1,2}'.
get_used_components(Term, Acc, Acc) :-
    var(Term),
    !.
get_used_components(b(Node,_,_), Acc, UsedComponents) :-
    !,
    Node =.. [Component|Args],
    map_get_used_components(Args, [Component|Acc], UsedComponents).
get_used_components(Term, Acc, UsedComponents) :-
    compound(Term),
    !,
    % lists and other non-AST wrappers: descend into the arguments only
    Term =.. [_|Args],
    map_get_used_components(Args, Acc, UsedComponents).
get_used_components(_Atomic, Acc, Acc).

%% map_get_used_components(+Terms, +Acc, -UsedComponents).
%
% Fold get_used_components/3 over a list of terms, threading the accumulator.
map_get_used_components([], Acc, Acc).
map_get_used_components([Arg|T], Acc, UsedComponents) :-
    get_used_components(Arg, Acc, NewAcc),
    map_get_used_components(T, NewAcc, UsedComponents).
| 231 | ||
%% generate_synthesis_data_from_predicate_raw(+MachinePath, +RawPredicate, -AugmentedSetOfData).
%
% Generate positive and negative examples for the textual predicate
% RawPredicate. Also extract the B library components it uses, as needed by
% the program synthesis tool. The machine at MachinePath is loaded first
% unless it is already the loaded machine.
%
% AugmentedSetOfData is a list of (PositiveExamples, NegativeExamples,
% UsedComponents) triples, one per data augmentation. Each example is a list
% of (MachineVar, Type, PrettyValue) triples.
% Uses the defaults augment_records/1 and solver_timeout_ms/1; see
% generate_synthesis_data_from_predicate_raw/5 to choose them explicitly.
generate_synthesis_data_from_predicate_raw(MachinePath, RawPredicate, AugmentedSetOfData) :-
    augment_records(DefaultAugmentRecords),
    solver_timeout_ms(DefaultTimeoutMs),
    generate_synthesis_data_from_predicate_raw(MachinePath, DefaultAugmentRecords, DefaultTimeoutMs,
                                               RawPredicate, AugmentedSetOfData).
| 244 | ||
% Cache for the normalized identifier mapping of the currently loaded machine.
% It is cleared by load_b_machine_if_unloaded/1 whenever a machine is (re)loaded.
:- dynamic normalized_id_name_mapping/4.
:- volatile normalized_id_name_mapping/4.

%% get_normalized_id_name_mapping_stateful(-NormalizedSets, -NormalizedIds, -NOperationNames, -NRecordFieldNames).
%
% Memoised get_normalized_id_name_mapping/4. A cached mapping is returned if
% one exists; otherwise the mapping is computed and cached.
get_normalized_id_name_mapping_stateful(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames) :-
    (   normalized_id_name_mapping(_, _, _, _)
    ->  normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames)
    ;   get_normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames),
        asserta(normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames))
    ).
| 254 | ||
%% load_b_machine_if_unloaded(+MachinePath).
%
% Load and precompile the machine at MachinePath unless it already is the
% precompiled main machine. Loading a new machine clears the cached
% normalized identifier mapping.
%
% MachinePath may be relative. It counts as loaded when it equals the loaded
% main file name, or is a suffix of it starting at a path-component boundary.
% A plain atom-suffix check would wrongly treat 'Lift.mch' as loaded while
% '/models/MyLift.mch' is the current machine, and would treat '' as always
% loaded.
load_b_machine_if_unloaded(MachinePath) :-
    bmachine_is_precompiled,
    b_get_main_filename(LoadedMachinePath),
    loaded_machine_path_matches(LoadedMachinePath, MachinePath),
    !.
load_b_machine_if_unloaded(MachinePath) :-
    retractall(normalized_id_name_mapping(_, _, _, _)),
    b_load_machine_from_file(MachinePath),
    b_machine_precompile.

%% loaded_machine_path_matches(+LoadedPath, +MachinePath).
%
% True if MachinePath denotes LoadedPath: either identical, or a suffix of
% LoadedPath preceded by a '/' or '\' separator.
loaded_machine_path_matches(LoadedPath, MachinePath) :-
    LoadedPath == MachinePath,
    !.
loaded_machine_path_matches(LoadedPath, MachinePath) :-
    atom_concat(Prefix, MachinePath, LoadedPath),
    sub_atom(Prefix, _, 1, 0, LastChar),
    memberchk(LastChar, ['/', '\\']),
    !.
| 264 | ||
%% get_id_name(+IdentifierAst, -Name).
%
% Name of a typed identifier node b(identifier(Name), Type, Info).
get_id_name(IdentifierAst, Name) :-
    IdentifierAst = b(identifier(Name), _Type, _Info).
| 266 | ||
%% generate_synthesis_data_from_predicate_raw(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +RawPredicate, -AugmentedSetOfData).
%
% Parse the atom RawPredicate and generate synthesis data for it.
% AugmentRecords is the number of data augmentations to attempt, and
% SolverTimeoutMs the solver time_out preference used while generating.
generate_synthesis_data_from_predicate_raw(MachinePath, AugmentRecords, SolverTimeoutMs, RawPredicate, AugmentedSetOfData) :-
    atom_codes(RawPredicate, PredicateCodes),
    parse_formula(PredicateCodes, ParsedPredicate),
    generate_synthesis_data_from_predicate_untyped(MachinePath, AugmentRecords, SolverTimeoutMs,
                                                   ParsedPredicate, AugmentedSetOfData).
| 274 | ||
%% generate_synthesis_data_from_predicate_untyped(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +UntypedPredAst, -AugmentedSetOfData).
%
% Ensure the machine at MachinePath is loaded, then type-check the parsed
% predicate in its context and generate synthesis data for the typed AST.
generate_synthesis_data_from_predicate_untyped(MachinePath, AugmentRecords, SolverTimeoutMs, UntypedPredAst, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    type_check_in_machine_context([UntypedPredAst], TypedAsts),
    % exactly one typed predicate is expected back
    TypedAsts = [TypedAst],
    generate_synthesis_data_from_predicate_ast(MachinePath, AugmentRecords, SolverTimeoutMs,
                                               TypedAst, AugmentedSetOfData).
| 282 | ||
%% generate_synthesis_data_from_predicate_ast(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +PredicateAst, -AugmentedSetOfData).
%
% Generate augmented positive/negative example data for the typed
% PredicateAst in the context of the machine at MachinePath. Fails if the
% machine uses records.
%
% Two groups of preferences are global ProB state:
% - the AST preferences optimize_ast and normalize_ast;
% - the solver preferences changed by set_desired_preferences/4.
% They are restored via call_cleanup/2, so they are also reset when
% generation fails or raises. Failure happens, for example, when
% filter_predicate/3 removes every constraint. Generation is committed to
% its first solution so that the cleanup runs as soon as it exits.
generate_synthesis_data_from_predicate_ast(MachinePath, AugmentRecords, SolverTimeoutMs, PredicateAst, AugmentedSetOfData) :-
    preferences:temporary_set_preference(optimize_ast, false),
    preferences:temporary_set_preference(normalize_ast, true),
    % should we also set normalize_ast_sort_commutative to true?
    call_cleanup(
        once(generate_synthesis_data_with_solver_prefs(MachinePath, AugmentRecords, SolverTimeoutMs,
                                                       PredicateAst, AugmentedSetOfData)),
        reset_synthesis_ast_preferences).

%% reset_synthesis_ast_preferences.
%
% Undo the temporary AST preferences set by
% generate_synthesis_data_from_predicate_ast/5.
reset_synthesis_ast_preferences :-
    preferences:reset_temporary_preference(optimize_ast),
    preferences:reset_temporary_preference(normalize_ast).

%% generate_synthesis_data_with_solver_prefs(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +PredicateAst, -AugmentedSetOfData).
%
% Load the machine, switch to the solver preferences used for data
% generation, and restore the previous values afterwards whatever the outcome.
generate_synthesis_data_with_solver_prefs(MachinePath, AugmentRecords, SolverTimeoutMs, PredicateAst, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    \+ current_machine_uses_records,
    retractall(tools:id_counter(_)),
    set_desired_preferences(SolverTimeoutMs, OldKodkodPref, OldTimeOutPref, OldRandPref),
    call_cleanup(
        once(generate_synthesis_data_for_loaded_machine(AugmentRecords, PredicateAst, AugmentedSetOfData)),
        reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref)).

%% generate_synthesis_data_for_loaded_machine(+AugmentRecords, +PredicateAst, -AugmentedSetOfData).
%
% Steps, in order:
% 1. Clean up and simplify PredicateAst.
% 2. Keep only the constraints on machine variables.
% 3. Extract the used library components.
% 4. Generate the augmented example records.
generate_synthesis_data_for_loaded_machine(AugmentRecords, PredicateAst, AugmentedSetOfData) :-
    b_get_machine_variables(MachineVars),
    maplist(get_id_name, MachineVars, MachineVarNames),
    % following includes constraints not using machine variables
    %b_get_all_used_identifiers(AllUsedIds),
    b_ast_cleanup:clean_up_pred(PredicateAst, _, CleanPred),
    b_simplifier:simplify_b_predicate(CleanPred, SimplifiedPred),
    filter_predicate(MachineVarNames, SimplifiedPred, FilteredPredicateAst),
    get_library_components_from_pred_or_expr(FilteredPredicateAst, UsedComponents),
    get_normalized_id_name_mapping_stateful(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames),
    Env = [[], NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames],
    %b_get_typed_invariant_from_machine(Invariant),
    find_typed_identifier_uses(FilteredPredicateAst, [], UsedIds),
    translate:generate_typing_predicates(UsedIds, TypingPredicates),
    conjunct_predicates(TypingPredicates, TypingPredicate),
    get_augmented_states_for_predicate(Env, AugmentRecords, MachineVars, UsedComponents,
                                       FilteredPredicateAst, TypingPredicate, AugmentedSetOfData, _).
| 310 | ||
%% reset_old_preferences(+OldKodkodPref, +OldTimeOutPref, +OldRandPref).
%
% Restore the preference values previously returned by
% set_desired_preferences/4. They are set in the order try_kodkod_on_load,
% time_out, randomise_enumeration_order.
reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    maplist(set_preference,
            [try_kodkod_on_load, time_out, randomise_enumeration_order],
            [OldKodkodPref, OldTimeOutPref, OldRandPref]).
| 317 | ||
%% set_desired_preferences(+Timeout, -OldKodkodPref, -OldTimeOutPref, -OldRandPref).
%
% Switch to the preferences used for data generation:
% - disable Kodkod on load;
% - enable randomised enumeration;
% - set the time_out preference to Timeout.
% The previous values are returned so that reset_old_preferences/3 can
% restore them.
set_desired_preferences(Timeout, OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    replace_preference_value(try_kodkod_on_load, false, OldKodkodPref),
    replace_preference_value(randomise_enumeration_order, true, OldRandPref),
    replace_preference_value(time_out, Timeout, OldTimeOutPref).

%% replace_preference_value(+Pref, +NewValue, -OldValue).
%
% Read the current value of Pref, then overwrite it with NewValue.
replace_preference_value(Pref, NewValue, OldValue) :-
    get_preference(Pref, OldValue),
    set_preference(Pref, NewValue).