1 | :- module(predicate_data_generator, [generate_synthesis_data_from_predicate_raw/3, | |
2 | generate_synthesis_data_from_predicate_raw/5, | |
3 | generate_synthesis_data_from_predicate_untyped/5]). | |
4 | ||
5 | :- use_module(library(lists)). | |
6 | :- use_module(library(random)). | |
7 | :- use_module(library(sets), [intersect/2]). | |
8 | ||
9 | :- use_module(probsrc(bmachine), [b_load_machine_from_file/1, | |
10 | b_get_machine_variables/1, | |
11 | b_get_all_used_identifiers/1, | |
12 | b_get_main_filename/1, | |
13 | bmachine_is_precompiled/0, | |
14 | b_machine_precompile/0]). | |
15 | :- use_module(probsrc(parsercall), [parse_formula/2]). | |
16 | :- use_module(probsrc(translate), [translate_bexpression/2]). | |
17 | :- use_module(probsrc(solver_interface), [solve_predicate/5, | |
18 | type_check_in_machine_context/2]). | |
19 | :- use_module(probsrc(bsyntaxtree), [conjunct_predicates/2, | |
20 | get_texpr_info/2, | |
21 | find_identifier_uses/3, | |
22 | find_typed_identifier_uses/3]). | |
23 | :- use_module(probsrc(preferences), [set_preference/2, | |
24 | get_preference/2]). | |
25 | ||
26 | :- use_module(probsrc('por/b_simplifier')). | |
27 | :- use_module(synthesis('deep_learning/ground_truth')). | |
28 | :- use_module(synthesis('deep_learning/b_machine_identifier_normalization')). | |
29 | :- use_module(synthesis(synthesis_util), [get_input_nodes_from_bindings/2, | |
30 | create_equality_nodes_from_example/2, | |
31 | b_get_typed_invariant_from_machine/1]). | |
32 | ||
% --- Sampling configuration -------------------------------------------------

% Lower bound for the number of positive (and for the number of negative)
% examples drawn per augmentation; also the minimum total of found states
% required to keep an augmentation record.
min_amount_of_examples(3).

% Inclusive upper bounds for the number of positive / negative examples.
max_amount_of_examples(positive, 8).
max_amount_of_examples(negative, 8).

% Default number of data augmentations per predicate.
augment_records(5).

% Default solver timeout in milliseconds.
solver_timeout_ms(10000).
39 | ||
%% get_random_amount_of_examples(-(PAmount,NAmount)).
%
% Draw a random number of positive and a random number of negative examples,
% each between min_amount_of_examples/1 and the corresponding
% max_amount_of_examples/2 (both inclusive).
get_random_amount_of_examples((PAmountOfExamples,NAmountOfExamples)) :-
    min_amount_of_examples(Lower),
    max_amount_of_examples(positive, PosMax),
    max_amount_of_examples(negative, NegMax),
    % random/3 of library(random) draws from [Lower, Upper), so the
    % exclusive upper bounds are one above the inclusive maxima.
    PosUpper is PosMax + 1,
    NegUpper is NegMax + 1,
    random(Lower, PosUpper, PAmountOfExamples),
    random(Lower, NegUpper, NAmountOfExamples).
48 | ||
%% random_list_of_numbers(+Count, +Acc, -Pairs).
%
% Prepend Count distinct random (PAmount,NAmount) pairs, drawn with
% get_random_amount_of_examples/1, to Acc.
% Count is capped at the number of distinct pairs that can be drawn at all
% (minus the pairs already in Acc); a negative Count is treated as 0.
% Without the cap the rejection loop below would never terminate, e.g. for
% Count > 36 with the default bounds 3..8.
random_list_of_numbers(Count, Acc, Pairs) :-
    max_distinct_amounts_of_examples(MaxDistinct),
    length(Acc, Known),
    Reachable is max(0, min(Count, MaxDistinct - Known)),
    random_distinct_amounts(Reachable, Acc, Pairs).

% Number of distinct (PAmount,NAmount) pairs get_random_amount_of_examples/1 can produce.
max_distinct_amounts_of_examples(MaxDistinct) :-
    min_amount_of_examples(Min),
    max_amount_of_examples(positive, PMax),
    max_amount_of_examples(negative, NMax),
    MaxDistinct is max(0, PMax - Min + 1) * max(0, NMax - Min + 1).

% Rejection sampling of Count pairs not yet contained in the accumulator.
random_distinct_amounts(0, Acc, Acc) :-
    !.
random_distinct_amounts(Count, Acc, Pairs) :-
    get_random_amount_of_examples(Pair),
    (   memberchk(Pair, Acc)
    ->  % duplicate: draw again without consuming the budget
        random_distinct_amounts(Count, Acc, Pairs)
    ;   Count1 is Count - 1,
        random_distinct_amounts(Count1, [Pair|Acc], Pairs)
    ).
59 | ||
%% exclude_solution(+State, +ExclusionPred, -NewExclusionPred).
%
% Extend ExclusionPred so that the valuation described by State is ruled out:
% the equalities built from State are conjoined and their negation is added
% (see truth_or_exclude/3). An empty State leaves ExclusionPred unchanged.
exclude_solution(State, ExclusionPred, NewExclusionPred) :-
    (   State = []
    ->  NewExclusionPred = ExclusionPred
    ;   create_equality_nodes_from_example(State, Equalities),
        conjunct_predicates(Equalities, Conjunction),
        truth_or_exclude(Conjunction, ExclusionPred, NewExclusionPred)
    ).
66 | ||
%% truth_or_exclude(+EqualityConj, +ExclusionPred, -NewExclusionPred).
%
% If EqualityConj is the trivial predicate b(truth,pred,[]), nothing is excluded
% and ExclusionPred is returned as is. Otherwise the result is
% not(EqualityConj) & ExclusionPred.
truth_or_exclude(EqualityConj, ExclusionPred, NewExclusionPred) :-
    (   EqualityConj = b(truth, pred, [])
    ->  NewExclusionPred = ExclusionPred
    ;   Negated = b(negation(EqualityConj), pred, []),
        NewExclusionPred = b(conjunct(Negated, ExclusionPred), pred, [])
    ).
70 | ||
%% filter_machine_var_states(+VarStates, +MachineVars, +Acc, -States).
%
% Keep the nodes of VarStates whose info list contains synthesis(machinevar, _).
% Kept nodes are pushed onto Acc, so they end up in reverse order.
% MachineVars is currently unused: operation parameters are also marked as
% machinevar but are not contained in MachineVars, so no membership check
% against MachineVars is performed.
filter_machine_var_states([], _, States, States).
filter_machine_var_states([VarState|Rest], MachineVars, Acc, States) :-
    (   get_texpr_info(VarState, Info),
        memberchk(synthesis(machinevar, _), Info)
    ->  NextAcc = [VarState|Acc]
    ;   NextAcc = Acc
    ),
    filter_machine_var_states(Rest, MachineVars, NextAcc, States).
81 | ||
%% map_translate_var_state(+Env, +VarAsts, -Triples, -NewEnv).
%
% Normalise the identifiers of each variable node (threading the normalisation
% environment from one node to the next) and translate it into a triple
% (VarName, Type, PrettyAst). VarName is taken from the
% synthesis(machinevar, VarName) entry of the normalised node's info, and
% PrettyAst is the result of translate_bexpression/2.
map_translate_var_state(Env, [], [], Env).
map_translate_var_state(EnvIn, [VarAst|VarAsts], [(VarName,Type,PrettyAst)|Triples], EnvOut) :-
    normalize_ids_in_b_ast(EnvIn, VarAst, NormAst, EnvNext),
    NormAst = b(_, Type, NormInfo),
    member(synthesis(machinevar, VarName), NormInfo),
    translate_bexpression(NormAst, PrettyAst),
    map_translate_var_state(EnvNext, VarAsts, Triples, EnvOut).
89 | ||
%% get_amount_of_states_for_predicate(+Env, +Amount, +ExclusionPred, +MachineVars, +PredicateAst, +Acc, -ListOfExamples, -NewEnv).
%
% Collect up to Amount pretty-printed states satisfying
% ExclusionPred & PredicateAst, prepending each one to Acc.
% After every found state its valuation is excluded (exclude_solution/3), so
% the next solver call has to produce a different state.
% Collection stops early, returning what has been found so far, when
% solve_predicate/5 does not yield solution(_), when the solution contains no
% machine variable, or when any other step of an iteration fails.
get_amount_of_states_for_predicate(Env, 0, _, _, _, Acc, Acc, Env) :-
    !.
get_amount_of_states_for_predicate(Env, Remaining, ExclusionPred, MachineVars, PredicateAst, Acc, ListOfExamples, NewEnv) :-
    (   Remaining1 is Remaining - 1,
        Query = b(conjunct(ExclusionPred,PredicateAst),pred,[]),
        solve_predicate(Query, _, 1, [force_evaluation], Result),
        Result = solution(Bindings),
        get_input_nodes_from_bindings(Bindings, CandidateNodes),
        filter_machine_var_states(CandidateNodes, MachineVars, [], State),
        State \== [],
        exclude_solution(State, ExclusionPred, NextExclusionPred),
        map_translate_var_state(Env, State, PrettyState, EnvNext)
    ->  get_amount_of_states_for_predicate(EnvNext, Remaining1, NextExclusionPred, MachineVars, PredicateAst, [PrettyState|Acc], ListOfExamples, NewEnv)
    ;   % cancel if no (usable) solution is found
        ListOfExamples = Acc,
        NewEnv = Env
    ).
105 | ||
%% get_augmented_states_for_predicate(+Env, +AR, +MachineVars, +UsedComponents, +PredicateAst, +TypingPredicate, -AugmentedSetOfData, -NewEnv).
%
% Draw AR distinct (PAmount,NAmount) pairs and sample one augmentation record
% per pair (see get_augmented_states_for_predicate_rand/8).
get_augmented_states_for_predicate(Env, AR, MachineVars, UsedComponents, PredicateAst, TypingPredicate, AugmentedSetOfData, NewEnv) :-
    random_list_of_numbers(AR, [], AmountPairs),
    get_augmented_states_for_predicate_rand(Env, AmountPairs, MachineVars, UsedComponents,
                                            PredicateAst, TypingPredicate, AugmentedSetOfData, NewEnv).
109 | ||
%% get_augmented_states_for_predicate_rand(+Env, +AmountPairs, +MachineVars, +UsedComponents, +PredicateAst, +TypingPredicate, -AugmentedSetOfData, -NewEnv).
%
% For every (PAmount,NAmount) pair, sample up to PAmount positive states of
% TypingPredicate & PredicateAst and up to NAmount negative states of
% TypingPredicate & not(PredicateAst). A record
% (PositiveStates, NegativeStates, UsedComponents) is emitted only if at least
% min_amount_of_examples/1 states were found in total. Otherwise the pair is
% skipped, and the normalisation environment of that attempt is discarded.
get_augmented_states_for_predicate_rand(Env, [], _, _, _, _, [], Env).
get_augmented_states_for_predicate_rand(Env, [Amounts|Rest], MachineVars, UsedComponents, PredicateAst, TypingPredicate, AugmentedSetOfData, NewEnv) :-
    (   Amounts = (PAmount, NAmount),
        PosPred = b(conjunct(TypingPredicate,PredicateAst),pred,[]),
        NegPred = b(conjunct(TypingPredicate,b(negation(PredicateAst),pred,[])),pred,[]),
        get_amount_of_states_for_predicate(Env, PAmount, b(truth,pred,[]), MachineVars, PosPred, [], PositiveStates, EnvAfterPos),
        get_amount_of_states_for_predicate(EnvAfterPos, NAmount, b(truth,pred,[]), MachineVars, NegPred, [], NegativeStates, EnvAfterNeg),
        length(PositiveStates, NumPos),
        length(NegativeStates, NumNeg),
        min_amount_of_examples(MinAmount),
        NumPos + NumNeg >= MinAmount
    ->  AugmentedSetOfData = [(PositiveStates,NegativeStates,UsedComponents)|RestData],
        NextEnv = EnvAfterNeg
    ;   AugmentedSetOfData = RestData,
        NextEnv = Env
    ),
    get_augmented_states_for_predicate_rand(NextEnv, Rest, MachineVars, UsedComponents, PredicateAst, TypingPredicate, RestData, NewEnv).
126 | ||
%% blacklist_component(?Component).
%
% B AST node functors that make filter_predicate/3 (and ast_uses_id_from_list/2)
% reject a constraint containing them.

% built-in integer sets
blacklist_component(int_set).
blacklist_component(integer_set).
blacklist_component(nat_set).
blacklist_component(nat1_set).
blacklist_component(natural_set).
blacklist_component(natural1_set).
% quantifiers and set comprehension
blacklist_component(forall).
blacklist_component(exists).
blacklist_component(comprehension_set).
% function and relation set constructors
blacklist_component(total_function).
blacklist_component(total_surjection).
blacklist_component(total_injection).
blacklist_component(total_relation).
blacklist_component(total_surjection_relation).
blacklist_component(surjection_relation).
blacklist_component(partial_function).
blacklist_component(partial_injection).
blacklist_component(partial_surjection).
blacklist_component(partial_bijection).
% further binders
blacklist_component(lambda).
blacklist_component(quantified_union).
blacklist_component(quantified_intersection).
% Deliberately NOT blacklisted at the moment: function, general_sum, general_product.
152 | ||
%% filter_predicate(+UsedIds, +TPredicateAst, -PredicateAst).
%
% Remove constraints that do not refer to any identifier in UsedIds, and
% constraints that use a component from blacklist_component/1.
% Fails if no constraint is left.
% For instance, 'x: NAT & x < 10 & NAT \/ NAT1 = NAT' will be reduced to 'x: NAT & x < 10' with x being in UsedIds.
% NOTE(review): integer_set is listed in blacklist_component/1; if NAT is
% represented as an integer_set node in the typed AST, 'x: NAT' is removed as
% well. Confirm the example above.
%
% Only conjunct/disjunct/equivalence/implication nodes are split into their
% operands. Dropping an operand of a disjunction, implication, equivalence or of
% a negated formula does not preserve the meaning of the original predicate.

% Binary logical connectives: filter both operands recursively.
% If filter_predicate_binary/6 fails (both operands dropped), no cut has been
% reached yet, so the next clause is tried on the same node.
filter_predicate(UsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    Node =.. [Functor, Lhs, Rhs],
    (   Functor == conjunct
    ;   Functor == disjunct
    ;   Functor == equivalence
    ;   Functor == implication),
    filter_predicate_binary(UsedIds, Ast, Functor, Lhs, Rhs, Clean),
    !,
    NewAst = Clean.
% remove constraint if contains component in blacklist_component/1
% Any node with exactly two arguments is kept or removed as a whole.
% The components of both operands are collected before the cut, so if
% get_used_components/3 fails for an operand, the remaining clauses are tried.
% After the cut, the node is kept only if neither operand uses a blacklisted
% component and ast_uses_id_from_list/2 succeeds (which also checks the node's
% own functor against the blacklist).
filter_predicate(UsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    Node =.. [_, Lhs, Rhs],
    get_used_components(Lhs, [], LhsUsedComponents),
    get_used_components(Rhs, [], RhsUsedComponents),
    findall(BlacklistCmpt, (member(BlacklistCmpt, LhsUsedComponents), blacklist_component(BlacklistCmpt)), LhsBlacklistCmpts),
    findall(BlacklistCmpt, (member(BlacklistCmpt, RhsUsedComponents), blacklist_component(BlacklistCmpt)), RhsBlacklistCmpts),
    !,
    (   LhsBlacklistCmpts == [],
        RhsBlacklistCmpts == []
    ->  ast_uses_id_from_list(UsedIds, Ast),
        NewAst = Ast
    ;   fail
    ).
% Negation: filter the operand and rebuild the negation
% (the info list of the original negation node is not kept).
filter_predicate(UsedIds, Negation, NegNewAst) :-
    Negation = b(negation(Ast),pred,_),
    !,
    filter_predicate(UsedIds, Ast, NewAst),
    NegNewAst = b(negation(NewAst),pred,[]).
% fail if ast does not use any ids from UsedIds
% Remaining nodes: kept unchanged if they use an id from UsedIds and contain
% no blacklisted component anywhere; otherwise this clause fails.
filter_predicate(UsedIds, Ast, NewAst) :-
    ast_uses_id_from_list(UsedIds, Ast),
    get_used_components(Ast, [], UsedComponents),
    findall(BlacklistCmpt, (member(BlacklistCmpt, UsedComponents), blacklist_component(BlacklistCmpt)), BlacklistCmpts),
    BlacklistCmpts == [],
    NewAst = Ast.
195 | ||
%% filter_predicate_binary(+UsedIds, +Ast, +Functor, +Lhs, +Rhs, -NewNode).
%
% Filter both operands of the binary connective Ast (= b(Functor(Lhs,Rhs),Type,Info)):
%  - both operands survive: rebuild the node with the same Functor, Type and Info,
%  - only one survives: return that operand alone (the dropped one would be empty/truth),
%  - none survives: fail (the whole predicate would be truth).
filter_predicate_binary(UsedIds, b(_,Type,Info), Functor, Lhs, Rhs, NewNode) :-
    (   filter_predicate(UsedIds, Lhs, FilteredLhs) -> LhsResult = kept(FilteredLhs) ; LhsResult = dropped ),
    (   filter_predicate(UsedIds, Rhs, FilteredRhs) -> RhsResult = kept(FilteredRhs) ; RhsResult = dropped ),
    (   LhsResult = kept(L1), RhsResult = kept(R1)
    ->  Rebuilt =.. [Functor, L1, R1],
        NewNode = b(Rebuilt, Type, Info)
    ;   LhsResult = kept(L2)
    ->  NewNode = L2
    ;   RhsResult = kept(R2)
    ->  NewNode = R2
    ),
    !.
210 | ||
%% ast_uses_id_from_list(+UsedIds, +Ast).
%
% Succeeds if the top-level functor of Ast is not blacklisted
% (blacklist_component/1) and Ast uses at least one identifier from UsedIds.
% Only the first result of find_identifier_uses/3 is considered.
ast_uses_id_from_list(UsedIds, Ast) :-
    once(( Ast = b(Node, _, _),
           functor(Node, TopFunctor, _Arity),
           \+ blacklist_component(TopFunctor),
           bsyntaxtree:find_identifier_uses(Ast, [], IdsInAst)
         )),
    intersect(UsedIds, IdsInAst).
218 | ||
%% get_used_components(+Ast, +Acc, -UsedComponents).
%
% Collect the functor names of all b/3 nodes in Ast and prepend them to Acc.
% Each node is prepended before the nodes of its arguments, so the most
% recently visited node comes first. The type and info fields of b/3 nodes are
% not inspected.
% Compound arguments that are not b/3 nodes (e.g. the operand list of a
% set_extension) are traversed transparently and contribute no name themselves.
% Previously such arguments made this predicate fail, which in turn made
% filter_predicate/3 silently drop every constraint containing them.
% Unbound variables are treated as leaves.
get_used_components(Term, Acc, Acc) :-
    var(Term),
    !.
get_used_components(b(Node,_,_), Acc, UsedComponents) :-
    !,
    Node =.. [Component|Args],
    map_get_used_components(Args, [Component|Acc], UsedComponents).
get_used_components(Term, Acc, Acc) :-
    atomic(Term),
    !.
get_used_components(Term, Acc, UsedComponents) :-
    % non-AST compound term (list, int(_), field(_,_), ...): descend into its arguments
    Term =.. [_|Args],
    map_get_used_components(Args, Acc, UsedComponents).

%% map_get_used_components(+Terms, +Acc, -UsedComponents).
%
% Fold get_used_components/3 over a list of terms, left to right.
map_get_used_components([], Acc, Acc).
map_get_used_components([Arg|T], Acc, UsedComponents) :-
    get_used_components(Arg, Acc, NewAcc),
    map_get_used_components(T, NewAcc, UsedComponents).
231 | ||
%% generate_synthesis_data_from_predicate_raw(+MachinePath, +RawPredicate, -AugmentedSetOfData).
%
% Generate positive and negative examples for a pretty-printed B predicate and
% extract the ground-truth B components as used by the program synthesis tool.
% The default number of augmentations (augment_records/1) and the default
% solver timeout (solver_timeout_ms/1) are used.
% AugmentedSetOfData is a list of triples
% (PositiveExamples, NegativeExamples, UsedComponents), one per successful data
% augmentation. Each example is a list of triples (MachineVar, Type, PrettyValue).
% The machine at MachinePath is loaded first unless it already is the loaded
% main machine. Fails without an error message, e.g. if the machine uses records
% or no constraint of the predicate refers to a machine variable.
generate_synthesis_data_from_predicate_raw(MachinePath, RawPredicate, AugmentedSetOfData) :-
    solver_timeout_ms(TimeoutMs),
    augment_records(Augmentations),
    generate_synthesis_data_from_predicate_raw(MachinePath, Augmentations, TimeoutMs,
                                               RawPredicate, AugmentedSetOfData).
244 | ||
% Cache for the identifier normalisation mapping of the currently loaded machine.
% It is cleared by load_b_machine_if_unloaded/1 whenever a machine is (re)loaded.
:- dynamic normalized_id_name_mapping/4.
:- volatile normalized_id_name_mapping/4.

%% get_normalized_id_name_mapping_stateful(-NormalizedSets, -NormalizedIds, -NOperationNames, -NRecordFieldNames).
%
% Memoised access to get_normalized_id_name_mapping/4: return the cached
% mapping if one has been stored; otherwise compute it and cache it with asserta/1.
get_normalized_id_name_mapping_stateful(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames) :-
    (   normalized_id_name_mapping(_, _, _, _)
    ->  normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames)
    ;   get_normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames),
        asserta(normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames))
    ).
254 | ||
%% load_b_machine_if_unloaded(+MachinePath).
%
% Load and precompile the machine at MachinePath, unless a precompiled machine
% whose main file is MachinePath is already loaded.
% Loading a machine clears the cached identifier normalisation mapping.
load_b_machine_if_unloaded(MachinePath) :-
    bmachine_is_precompiled,
    b_get_main_filename(LoadedMachinePath),
    machine_path_matches(MachinePath, LoadedMachinePath),
    !.
load_b_machine_if_unloaded(MachinePath) :-
    retractall(normalized_id_name_mapping(_, _, _, _)),
    b_load_machine_from_file(MachinePath),
    b_machine_precompile.

%% machine_path_matches(+RequestedPath, +LoadedPath).
%
% True if LoadedPath equals RequestedPath, or ends with RequestedPath directly
% preceded by a directory separator. A bare suffix test would wrongly accept
% e.g. 'a.mch' when '/models/data.mch' is loaded.
machine_path_matches(RequestedPath, LoadedPath) :-
    RequestedPath == LoadedPath,
    !.
machine_path_matches(RequestedPath, LoadedPath) :-
    member(Separator, ['/', '\\']),
    atom_concat(Separator, RequestedPath, SeparatedPath),
    atom_concat(_, SeparatedPath, LoadedPath),
    !.
264 | ||
% get_id_name(+IdentifierAst, -Name): name of a b(identifier(Name),_,_) node.
get_id_name(IdentifierAst, Name) :-
    IdentifierAst = b(identifier(Name), _, _).
266 | ||
%% generate_synthesis_data_from_predicate_raw(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +RawPredicate, -AugmentedSetOfData).
%
% Like generate_synthesis_data_from_predicate_raw/3, but with an explicit number
% of data augmentations (AugmentRecords) and solver timeout in milliseconds.
% RawPredicate is an atom; it is parsed with parse_formula/2 and then handled by
% generate_synthesis_data_from_predicate_untyped/5.
generate_synthesis_data_from_predicate_raw(MachinePath, AugmentRecords, SolverTimeoutMs, RawPredicate, AugmentedSetOfData) :-
    atom_codes(RawPredicate, PredicateCodes),
    parse_formula(PredicateCodes, ParsedAst),
    generate_synthesis_data_from_predicate_untyped(MachinePath, AugmentRecords, SolverTimeoutMs,
                                                   ParsedAst, AugmentedSetOfData).
274 | ||
%% generate_synthesis_data_from_predicate_untyped(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +UntypedPredAst, -AugmentedSetOfData).
%
% Load the machine at MachinePath if needed, type-check the parsed (untyped)
% predicate in the machine's context, and generate the synthesis data for the
% resulting typed predicate.
generate_synthesis_data_from_predicate_untyped(MachinePath, AugmentRecords, SolverTimeoutMs, UntypedPredAst, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    type_check_in_machine_context([UntypedPredAst], TypedAsts),
    TypedAsts = [TypedAst],
    generate_synthesis_data_from_predicate_ast(MachinePath, AugmentRecords, SolverTimeoutMs,
                                               TypedAst, AugmentedSetOfData).
282 | ||
%% generate_synthesis_data_from_predicate_ast(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +PredicateAst, -AugmentedSetOfData).
%
% Generate augmented positive/negative example sets for the typed predicate
% PredicateAst in the context of the machine at MachinePath.
% All preferences changed here (optimize_ast, normalize_ast and the solver
% preferences of set_desired_preferences/4) are restored on success, on
% failure (e.g. machine uses records, nothing left after filtering) and on
% exceptions, via call_cleanup/2. Previously they leaked in those cases.
% The generation is committed to its first solution, so the cleanup runs
% immediately.
generate_synthesis_data_from_predicate_ast(MachinePath, AugmentRecords, SolverTimeoutMs, PredicateAst, AugmentedSetOfData) :-
    % Set before loading, as in the original ordering -- presumably so that a
    % (re)load/precompile of the machine uses these AST preferences too.
    preferences:temporary_set_preference(optimize_ast, false),
    preferences:temporary_set_preference(normalize_ast, true),
    % should we also set normalize_ast_sort_commutative to true?
    call_cleanup(
        once(synthesis_data_with_ast_preferences(MachinePath, AugmentRecords, SolverTimeoutMs,
                                                 PredicateAst, AugmentedSetOfData)),
        ( preferences:reset_temporary_preference(optimize_ast),
          preferences:reset_temporary_preference(normalize_ast) )).

% Load the machine (if needed), reject machines using records, tune the solver
% preferences and generate the data; the solver preferences are restored in every case.
synthesis_data_with_ast_preferences(MachinePath, AugmentRecords, SolverTimeoutMs, PredicateAst, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    \+ current_machine_uses_records,
    retractall(tools:id_counter(_)),
    set_desired_preferences(SolverTimeoutMs, OldKodkodPref, OldTimeOutPref, OldRandPref),
    call_cleanup(
        once(synthesis_data_for_loaded_machine(AugmentRecords, PredicateAst, AugmentedSetOfData)),
        reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref)).

% Clean up, simplify and filter PredicateAst (keeping only constraints on machine
% variables), then sample the augmented example sets for it.
synthesis_data_for_loaded_machine(AugmentRecords, PredicateAst, AugmentedSetOfData) :-
    b_get_machine_variables(MachineVars),
    maplist(get_id_name, MachineVars, MachineVarNames),
    b_ast_cleanup:clean_up_pred(PredicateAst, _, CleanPred),
    b_simplifier:simplify_b_predicate(CleanPred, SimplifiedPred),
    filter_predicate(MachineVarNames, SimplifiedPred, FilteredPredicateAst),
    get_library_components_from_pred_or_expr(FilteredPredicateAst, UsedComponents),
    get_normalized_id_name_mapping_stateful(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames),
    Env = [[], NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames],
    % typing predicates for all identifiers still used after filtering
    find_typed_identifier_uses(FilteredPredicateAst, [], UsedIds),
    translate:generate_typing_predicates(UsedIds, TypingPredicates),
    conjunct_predicates(TypingPredicates, TypingPredicate),
    get_augmented_states_for_predicate(Env, AugmentRecords, MachineVars, UsedComponents,
                                       FilteredPredicateAst, TypingPredicate, AugmentedSetOfData, _).
310 | ||
%% reset_old_preferences(+OldKodkodPref, +OldTimeOutPref, +OldRandPref).
%
% Restore the preferences try_kodkod_on_load, time_out and
% randomise_enumeration_order (in this order) to the given values, typically
% those returned by set_desired_preferences/4.
reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    forall(member(PrefName-PrefValue,
                  [ try_kodkod_on_load-OldKodkodPref,
                    time_out-OldTimeOutPref,
                    randomise_enumeration_order-OldRandPref ]),
           set_preference(PrefName, PrefValue)).
317 | ||
%% set_desired_preferences(+Timeout, -OldKodkodPref, -OldTimeOutPref, -OldRandPref).
%
% Disable try_kodkod_on_load, enable randomise_enumeration_order and set
% time_out to Timeout. Return the previous values so that they can be restored
% with reset_old_preferences/3.
set_desired_preferences(Timeout, OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    swap_preference(try_kodkod_on_load, false, OldKodkodPref),
    swap_preference(randomise_enumeration_order, true, OldRandPref),
    swap_preference(time_out, Timeout, OldTimeOutPref).

% swap_preference(+Name, +NewValue, -OldValue): read the current value of
% preference Name, then set it to NewValue.
swap_preference(Name, NewValue, OldValue) :-
    get_preference(Name, OldValue),
    set_preference(Name, NewValue).