1 % (c) 2022-2025 Lehrstuhl fuer Softwaretechnik und Programmiersprachen,
2 % Heinrich Heine Universitaet Duesseldorf
3 % This software is licenced under EPL 1.0 (http://www.eclipse.org/org/documents/epl-v10.html)
4
5 % MCTS: Monte-Carlo Tree Search
6 % initial version by Michael Leuschel, May 2022
7
8 % requires:
9 % GAME_OVER
10 % GAME_VALUE
11 % GAME_PLAYER
12 % GAME_MCTS_RUNS
13 % GAME_MCTS_TIMEOUT
14 % GAME_MCTS_CACHE_LAST_TREE
15 % GAME_MCTS_MAX_SIMULATION_LENGTH
16
% TODO: allow use for CSP, ...
%       improve performance of evaluation by preparing the state once and avoiding repeated unpacking
19
20 :- module(mcts_game_play, [mcts_auto_play/4, mcts_auto_play_available/0,
21 tcltk_gen_last_mcts_tree/2, mcts_tree_available/0]).
22
23
24 :- use_module(probsrc(module_information)).
25 :- module_info(group,misc).
26 :- module_info(description,'This module provides Monte Carlo Tree Search for games.').
27
28 % use_module(extrasrc(mcts_game_play)), current_state_id(I), mcts_auto_play(I,Action,TID,Dest).
29
30 :- use_module(probsrc(error_manager)).
31 :- use_module(probsrc(debug)).
32 :- use_module(probsrc(state_space),[current_state_id/1, transition/4,
33 visited_expression/2]).
34 :- use_module(probsrc(tcltk_interface),[compute_all_transitions_if_necessary/2]).
35
% game_move(+Src,?Op,?Tid,?Dst)
% Enumerates the transitions leaving state Src. All outgoing transitions of
% Src are computed first (if that has not happened yet).
game_move(Src,Op,Tid,Dst) :-
    compute_all_transitions_if_necessary(Src,true),
    transition(Src,Op,Tid,Dst).
39
40
41
42 :- use_module(library(random),[random_select/3]).
% my_random_game_move(+StateID,-SuccID)
% Chooses the successor used for one step of a random playout (used when the
% game does not provide its own random move).
% Mini-max style shortcut: if some move leads directly to a terminal state that
% is winning for the player making the move, that successor is taken instead
% of a random one. Otherwise one of the transitions is picked at random.
% Drawback: all transitions of StateID are computed (adding all successors to
% the state space) before one of them is selected.
% TODO: investigate whether we can implement a random move using random enumeration
my_random_game_move(StateID,SuccID) :-
    (   winning_move(StateID,WinningSucc)
    ->  SuccID = WinningSucc
    ;   findall(Dst,game_move(StateID,_,_,Dst),Successors),
        random_select(SuccID,Successors,_)
    ).
53
% winning_move(+StateID,-SuccID)
% SuccID is reachable in one move from StateID and is a terminal state that is
% winning for the player moving out of StateID. game_move/4 has computed all
% successors anyway, so we may as well look for a winning one among them.
% TODO: extend this detection also to non-terminal nodes?!
winning_move(StateID,SuccID) :-
    game_move(StateID,_Action,_TransID,SuccID),
    terminal(SuccID,_Utility,winning).
57
% other_player(?Player,?Opponent): the opponent of min is max and vice versa.
other_player(min,max).
other_player(max,min).
60
61 :- use_module(probsrc(state_space_exploration_modes),
62 [compute_heuristic_function_for_state_id/2]).
63 :- use_module(probsrc(eventhandling),[register_event_listener/3]).
64 :- register_event_listener(reset_specification,reset_mcts,'Reset MCTS information').
65 :- register_event_listener(reset_prob,reset_mcts,'Reset MCTS information').
66
:- dynamic terminal_node_cache/3.

% reset_mcts
% Forgets all cached terminal / non-terminal node information and the saved
% MCTS tree (called via the reset_specification and reset_prob event listeners).
reset_mcts :-
    retractall(terminal_node_cache(_,_,_)),
    retractall(saved_mcts_tree(_,_)).
69
70 :- use_module(probsrc(specfile),[state_corresponds_to_fully_setup_b_machine/2]).
71
% utility(+State,-Value)
% Evaluates the mandatory GAME_VALUE definition in State and converts the
% result into a Prolog number.
utility(State,Value) :-
    eval_animation_expression(State,'GAME_VALUE',GameValue,mandatory),
    get_number(GameValue,Value).
75
% terminal(+NodeID,?Val,?WinForParent)
% Succeeds if NodeID is a terminal (game over) node. Val is its utility
% (GAME_VALUE, 0 if it cannot be computed), and WinForParent is `winning` if
% that utility is a win for the player who moved into NodeID, otherwise
% `not_directly_winning`. Fails for non-terminal nodes.
% Results (including non-terminal status) are cached in terminal_node_cache/3.
% The outputs are unified only AFTER caching, so that callers passing a bound
% WinForParent (e.g. winning_move/2 passes `winning`) still populate the cache
% for non-winning terminal nodes instead of re-evaluating them on every call.
terminal(NodeID,Val,WinForParent) :-
    terminal_node_cache(NodeID,CachedVal,CachedWin),
    !,
    CachedVal \= non_terminal,
    Val = CachedVal,
    WinForParent = CachedWin.
terminal(NodeID,Val,WinForParent) :-
    visited_expression(NodeID,State),
    eval_animation_expression(State,'GAME_OVER',GameOver,mandatory),
    (   is_true(GameOver)
    ->  (   utility(State,Utility)
        ->  true
        ;   add_warning(mcts_game_play,'Could not compute utility for terminal node: ',NodeID),
            Utility = 0
        ),
        (   winning_node_for_parent(State,Utility)
        ->  WinInfo = winning
        ;   WinInfo = not_directly_winning
        ),
        assert(terminal_node_cache(NodeID,Utility,WinInfo)),
        % unify with the caller's (possibly bound) arguments only after caching:
        Val = Utility,
        WinForParent = WinInfo
    ;   assert(terminal_node_cache(NodeID,non_terminal,not_directly_winning)),
        fail
    ).
92
% is_true(?Value): Value is one of the accepted encodings of truth.
is_true(pred_true).
is_true(true).      % XTL
is_true('TRUE').
96
% get_player_in_state_id(+NodeID,-Player)
% Player (min or max) whose turn it is in the state stored for NodeID.
% Reports an internal error and fails if no state is stored for NodeID.
get_player_in_state_id(NodeID,Player) :-
    (   visited_expression(NodeID,State)
    ->  get_player(State,Player)
    ;   add_internal_error('Illegal state:',get_player_in_state_id(NodeID,Player)),
        fail
    ).
102
% get_player(+State,-Player)
% Player is max if GAME_PLAYER evaluates to a value denoting the maximising
% player in State, and min in all other cases (including evaluation failure).
get_player(State,Player) :-
    (   eval_animation_expression(State,'GAME_PLAYER',PlayerVal,mandatory),
        is_max_player(PlayerVal)
    ->  Player = max
    ;   Player = min
    ).
108
109 :- use_module(probsrc(b_global_sets),[is_b_global_constant/3]).
% is_max_player(+Value)
% Succeeds if the GAME_PLAYER value Value denotes the maximising player.
% Accepted encodings: string(Name), an enumerated-set element fd(Nr,Set)
% (judged by the constant's name), pred_true/pred_false (TRUE = max) and
% plain atoms (XTL). Unrecognised names are reported via is_max_aux/1.
is_max_player(Value) :-
    (   Value = string(Name)
    ->  is_max_aux(Name)
    ;   Value = fd(Nr,Set),
        is_b_global_constant(Set,Nr,Cst)
    ->  is_max_aux(Cst)
    ;   Value = pred_true
    ->  true
    ;   Value = pred_false
    ->  fail
    ;   is_max_aux(Value) % XTL
    ).

% is_max_aux(+Name): Name is a spelling of max; spellings of min fail quietly,
% anything else raises an error and fails.
is_max_aux(Name) :-
    (   is_max_atom(Name)
    ->  true
    ;   is_min_atom(Name)
    ->  fail
    ;   add_error(mcts_game_play,'Illegal GAME_PLAYER value, must be min or max:',Name),
        fail
    ).

% accepted spellings of the two players
is_max_atom(max).
is_max_atom('MAX').
is_max_atom('Max').

is_min_atom(min).
is_min_atom('MIN').
is_min_atom('Min').
126
127
128 :- use_module(probsrc(xtl_interface),[xtl_game_info/3]).
129
% mcts_auto_play_available
% Succeeds if the loaded model seems set up for MCTS game play: it provides
% an animation expression GAME_PLAYER or GAME_MCTS_RUNS, or, in XTL mode,
% game info for GAME_MCTS_RUNS.
% TODO: also provide version which checks if available in state
mcts_auto_play_available :-
    b_get_machine_animation_expression('GAME_PLAYER',_),
    !.
mcts_auto_play_available :-
    b_get_machine_animation_expression('GAME_MCTS_RUNS',_),
    !.
mcts_auto_play_available :-
    xtl_mode,
    !,
    xtl_game_info('GAME_MCTS_RUNS',_,_),
    !.
135
136 :- use_module(probsrc(bmachine),[b_get_machine_animation_expression/2,b_get_definition/5]).
137 :- use_module(probsrc(specfile),[csp_mode/0, xtl_mode/0]).
% eval_animation_expression_in_state_id(+NodeID,+Name,-Value,+Mandatory)
% Fetches the state stored for NodeID and evaluates the game definition Name in it.
eval_animation_expression_in_state_id(NodeID,Name,Value,Mandatory) :-
    visited_expression(NodeID,State),
    eval_animation_expression(State,Name,Value,Mandatory).
141
142 :- use_module(probsrc(b_interpreter), [ b_compute_explicit_epression_no_wf/6]).
% eval_animation_expression(+State,+Name,-Value,+Mandatory)
% Evaluates the game definition Name in State:
%  - B: via the machine's animation expression Name;
%  - XTL: via xtl_game_info/3 (warning if missing and Mandatory = mandatory);
%  - a DEFINITION of the wrong form (arguments / not an expression) gives a
%    warning and fails;
%  - otherwise fails, warning only if Mandatory = mandatory.
eval_animation_expression(State,Name,Value,_) :-
    b_get_machine_animation_expression(Name,TypedExpr),
    !,
    state_corresponds_to_fully_setup_b_machine(State,BState),
    b_compute_explicit_epression_no_wf(TypedExpr,[],BState,Value,'MCTS',0).
eval_animation_expression(State,Name,Value,Mandatory) :-
    xtl_mode,
    !,
    (   xtl_game_info(Name,State,Info)
    ->  Value = Info
    ;   Mandatory = mandatory,
        add_warning(mcts_game_play,'Add definition for prob_game_info(Key,State,Val) with Key = ',Name),
        fail
    ).
eval_animation_expression(_,Name,_,_) :-
    b_get_definition(Name,DefType,_Args,_Body,_Deps),
    !,
    (   DefType = expression
    ->  Msg = 'Please rewrite DEFINITION to use no arguments: '
    ;   Msg = 'Please rewrite DEFINITION to return an expression and use no arguments: '
    ),
    add_warning(mcts_game_play,Msg,Name),
    fail.
eval_animation_expression(_,Name,_,mandatory) :-
    add_warning(mcts_game_play,'Please add DEFINITION for: ',Name),
    fail.
160
161 % TODO: get definitions also from VisB JSON files (e.g., for Event-B, ...)
162
163 % ------------------------------
164
165
% mcts_auto_play(+StateID,-Action,-TransID,-SuccID)
% Reads the MCTS settings from the game definitions and lets MCTS choose a
% move from StateID. Defaults: 1000 runs, 5000 ms timeout, simulation length
% 100, tree caching enabled unless GAME_MCTS_CACHE_LAST_TREE is not true.
mcts_auto_play(StateID,Action,TransID,SuccID) :-
    get_animation_expression_nr(StateID,'GAME_MCTS_RUNS',1000,Runs),
    get_animation_expression_nr(StateID,'GAME_MCTS_TIMEOUT',5000,TimeoutMs),
    retractall(max_simulation_length(_)),
    get_animation_expression_nr(StateID,'GAME_MCTS_MAX_SIMULATION_LENGTH',100,MaxLen),
    assert(max_simulation_length(MaxLen)),
    (   once(eval_animation_expression_in_state_id(StateID,'GAME_MCTS_CACHE_LAST_TREE',CacheVal,not_mandatory)),
        \+ is_true(CacheVal)
    ->  CacheMode = no_cache
    ;   CacheMode = cache
    ),
    mcts_auto_play(StateID,CacheMode,TimeoutMs,Runs,Action,TransID,SuccID).
176
% get_animation_expression_nr(+StateID,+DefName,+Default,-Value)
% Evaluates the optional numeric setting DefName (GAME_MCTS_RUNS,
% GAME_MCTS_TIMEOUT, GAME_MCTS_MAX_SIMULATION_LENGTH, ...) in StateID.
% - absent definition: Value = Default;
% - non-negative number: Value is that number, floats rounded to an integer
%   (these settings are counts / milliseconds, and e.g. the run counter is
%   used with mod/2, which raises a type error on floats);
% - anything else: a warning naming DefName is issued and Value = Default.
get_animation_expression_nr(StateID,DefName,Default,Value) :-
    (   eval_animation_expression_in_state_id(StateID,DefName,Raw,not_mandatory)
    ->  (   get_number(Raw,Num),
            Num >= 0
        ->  (   float(Num)
            ->  Value is round(Num)
            ;   Value = Num
            )
        ;   ajoin([DefName,' should return a non-negative number: '],Msg),
            add_warning(mcts_game_play,Msg,Raw),
            Value = Default
        )
    ;   Value = Default
    ).
184
% get_number(+Value,-Number)
% Extracts a Prolog number from an int(N) value, a term(floating(F)) value,
% or a plain Prolog number (e.g. from XTL).
get_number(int(N),N).
get_number(term(floating(F)),F).
get_number(Num,Num) :-
    number(Num).
188
:- dynamic saved_mcts_tree/2.

% save_last_mcts_tree(+CacheMode,+SuccID,+Tree)
% With CacheMode = cache, remembers Tree together with the chosen successor
% SuccID, replacing any previously saved tree; other modes store nothing.
% The subtree for the chosen move can later be obtained via
% get_mcts_child_for_state(Tree,SuccID,SubTree).
save_last_mcts_tree(CacheMode,SuccID,Tree) :-
    (   CacheMode = cache
    ->  retractall(saved_mcts_tree(_,_)),
        assert(saved_mcts_tree(Tree,SuccID))
    ;   true
    ).
196
197
198
% mcts_auto_play(+StateID,+CacheMode,+TimeoutMs,+Runs,-Action,-TransID,-SuccID)
% Runs MCTS from StateID. With CacheMode = cache and a saved tree, the saved
% tree is re-used: descend from its root to the successor chosen last time and
% from there to StateID (presumably the state after the opponent's reply).
% Otherwise the search starts from a fresh leaf. A saved tree is always removed
% before searching; the new tree is stored again only when caching.
mcts_auto_play(StateID,CacheMode,TimeoutMs,Runs,Action,TransID,SuccID) :-
    (   retract(saved_mcts_tree(SavedTree,PrevSucc)),
        CacheMode = cache,
        % TODO: check if this is not already the right tree; check if we do not need another move
        get_mcts_child_for_state(SavedTree,PrevSucc,AfterPrevMove),
        get_mcts_child_for_state(AfterPrevMove,StateID,StartTree)
    ->  true
    ;   StartTree = leaf(StateID)
    ),
    mcts_incr_auto_play(StateID,TimeoutMs,Runs,Action,TransID,SuccID,StartTree,FinalTree),
    save_last_mcts_tree(CacheMode,SuccID,FinalTree).
208
% mcts_incr_auto_play(+StateID,+TimeoutMs,+Runs,-Action,-TransID,-SuccID,+Tree,-FinalTree)
% Variant of MCTS auto play where the initial tree Tree is supplied, allowing
% incremental re-use. Prints timing information; fails if no move was found.
mcts_incr_auto_play(StateID,TimeoutMs,Runs,Action,TransID,SuccID,Tree,FinalTree) :-
    format('Starting MCTS AUTO PLAY, SimRuns=~w, Timeout=~w~n',[Runs,TimeoutMs]),
    statistics(walltime,[T0|_]),
    (   mcts_run(TimeoutMs,Runs,Tree,FinalTree,Visits,SuccID)
    ->  statistics(walltime,[T1|_]),
        ElapsedMs is T1 - T0,
        game_move(StateID,Action,TransID,SuccID),
        format('Move found by MCTS in ~w ms (~w runs, ~w visits): ~w (~w --(~w)--> ~w)~n',
               [ElapsedMs,Runs,Visits,Action,StateID,TransID,SuccID])
    ).
219
220
221 %mcts_run(Target) :- start(Init), mcts_run(5000,10000,leaf(Init),_,_,Target).
222
% mcts_run(+TimeoutMs,+Runs,+Tree,-FinalTree,-Visits,-Target)
% Grows Tree with MCTS iterations (bounded by Runs and, roughly, by TimeoutMs)
% and returns the root's most visited child Target with its visit count.
% With debug mode on, the top levels of the final tree are printed.
mcts_run(TimeoutMs,Runs,Tree,FinalTree,Visits,Target) :-
    statistics(walltime,[Now|_]),
    Deadline is Now + TimeoutMs,
    mcts_loop(Deadline,Runs,Tree,FinalTree),
    get_best_mcts_move(FinalTree,Visits,Target),
    get_node(FinalTree,Root),
    debug_format(19,'Best move from ~w is to ~w (~w visits)~n',[Root,Target,Visits]),
    (   debug_mode(on)
    ->  print_tree(FinalTree,0,3)
    ;   true
    ).
231
232
% mcts_loop(+Deadline,+Runs,+Tree,-FinalTree)
% Performs Runs MCTS iterations on Tree, stopping early once the wall-clock
% Deadline has passed (the clock is only checked every 10 iterations).
% Previously the guard was Runs > 1, which performed only Runs-1 iterations
% (none at all for Runs = 1, making the move selection fail on a fresh leaf).
mcts_loop(Deadline,Runs,Tree,FinalTree) :-
    Runs > 0,
    !,
    mcts(Tree,_,NewTree),
    Remaining is Runs - 1,
    (   Remaining > 0,  % no point checking the clock after the last run
        mcts_time_out(Deadline,Remaining,Tree)
    ->  FinalTree = NewTree
    ;   mcts_loop(Deadline,Remaining,NewTree,FinalTree)
    ).
mcts_loop(_,_,Tree,Tree).

% mcts_time_out(+Deadline,+Remaining,+Tree)
% Succeeds, printing a message, if Remaining is a multiple of 10 and the
% current wall-clock time exceeds Deadline. Tree is only used for debugging.
mcts_time_out(Deadline,Remaining,_Tree) :-
    0 is Remaining mod 10,
    statistics(walltime,[Now|_]),
    %format(' ~w -> ',[Remaining]), print_tree_summary(_Tree),
    Now > Deadline,
    !,
    format('MCTS TIME-OUT with ~w runs remaining.~n',[Remaining]).
246
247
248 :- use_module(library(lists),[maplist/3, max_member/2, reverse/2]).
249
% get_mcts_child_for_state(+Tree,+StateID,-SubTree)
% Finds the subtree for StateID: the root of Tree itself or one of its direct
% children. Used to update the tree after a move was made.
% If neither matches, a fresh leaf for StateID (a new root) is returned and a
% short message is printed. Only the root id is printed, not the whole tree
% term, which can contain thousands of nodes.
get_mcts_child_for_state(node(StateID,Wins,Visits,Children),StateID,SubTree) :- !,
    % the node itself; we apply MCTS directly for other player
    SubTree = node(StateID,Wins,Visits,Children).
get_mcts_child_for_state(leaf(StateID),StateID,SubTree) :- !,
    % unexpanded root for the requested state: nothing to re-use, no message needed
    SubTree = leaf(StateID).
get_mcts_child_for_state(node(_,_,_,Children),StateID,Child) :-
    member(Child,Children),
    get_node(Child,StateID),
    !.
get_mcts_child_for_state(Tree,StateID,Child) :-
    (   get_node(Tree,Root) -> true ; Root = '?' ),
    format('MCTS: cannot find child for state ~w in tree rooted at ~w; creating new root~n',
           [StateID,Root]),
    Child = leaf(StateID). % create a new root
261
% get_best_mcts_move(+Tree,-MaxVisits,-Target)
% Target is a child of Tree's root with the highest visit count MaxVisits.
% The first such child is returned first; others follow on backtracking.
% Fails for a leaf (and presumably for a node without children).
get_best_mcts_move(node(_,_,_,Children),MaxVisits,Target) :-
    maplist(get_visits,Children,VisitCounts),
    max_member(MaxVisits,VisitCounts),
    member(BestChild,Children),
    get_visits(BestChild,MaxVisits),
    get_node(BestChild,Target).
268
% invert_win(+Win,-OpponentWin)
% Converts a win score (1 win, 0.5 draw, 0 loss) to the opponent's view,
% i.e. OpponentWin = 1 - Win; the common values are mapped without arithmetic.
invert_win(Win,OpponentWin) :-
    (   Win = 0   -> OpponentWin = 1
    ;   Win = 1   -> OpponentWin = 0
    ;   Win = 0.5 -> OpponentWin = 0.5
    ;   OpponentWin is 1 - Win
    ).
273
% mcts(+Tree,-Win,-NewTree)
% One MCTS iteration on Tree: selection (UCB), expansion of a leaf,
% simulation (random playout) and backpropagation.
% Win is the result (1 win, 0.5 draw, 0 loss) from the point of view of the
% player who moved into Tree's root; it is added to the root's win count.
% A failing child selection is reported as an internal error. Previously a
% leftover debugging `trace` call dropped the user into the Prolog tracer.
mcts(node(State,Wins,Visits,Childs),OuterWin,node(State,Wins1,V1,Childs1)) :-
    V1 is Visits+1,
    (   Childs = []
    ->  % the node has no children; simulate it, i.e., compute the utility value
        Childs1 = [],
        simulate_for_parent(State,_,OuterWin,terminal)
    ;   LogNi is log(V1),
        (   select_best_ucb_child(Childs,State,LogNi,Child,Childs1,Child1)
        ->  true
        ;   add_internal_error('MCTS could not select a child of node:',State),
            fail
        ),
        mcts(Child,ChildWin,Child1),
        invert_win(ChildWin,OuterWin)
    ),
    % backpropagate:
    Wins1 is Wins+OuterWin.
mcts(leaf(State),Wins,node(State,Wins,1,Childs)) :-
    simulate_for_parent(State,_Val,Wins,Kind),
    (   Kind = terminal
    ->  Childs = [] % do not add any children; game is over already
    ;   winning_move(State,Child)
    ->  Childs = [leaf(Child)] % minimax optimisation: pretend other children do not exist
    ;   findall(leaf(C),game_move(State,_,_,C),Childs)
    ).
297
% select_best_ucb_child(+Children,+Parent,+LogNi,-Best,-NewChildren,-NewBest)
% Best is the child with maximal UCB value (on ties the earliest one wins).
% NewChildren is the child list with Best replaced by the still unbound
% NewBest, which is placed first; the remaining children may be reordered.
select_best_ucb_child([Only],_,_,Only,[NewOnly],NewOnly) :- !. % no need to compute when there is a single child
select_best_ucb_child([First|Others],_Parent,LogNi,Best,[NewBest|Rest],NewBest) :- !,
    ucb(First,LogNi,FirstUCB),
    get_max_ucb(Others,LogNi,FirstUCB,First,[],Best,Rest).

%select_best_ucb_child(Children,_Parent,LogNi,Child,NewChildren,NewChild) :- % old version with sort/maplist
%   maplist(create_ucb_node(LogNi),Children,UC),
%   sort(UC,UCS), reverse(UCS,RUCS), % this can be done more efficiently
%   maplist(project,RUCS,SortedChildren),
%   SortedChildren = [Child|Rest],
%   NewChildren = [NewChild|Rest].

% get_max_ucb(+Cands,+LogNi,+CurMax,+BestSoFar,+RestSoFar,-Best,-Rest)
% Scans Cands for a node with strictly larger UCB than CurMax; every node
% that is not the final best one is collected (in reverse order) into Rest.
get_max_ucb([],_,_,Best,Rest,Best,Rest).
get_max_ucb([Cand|Cands],LogNi,CurMax,BestSoFar,RestSoFar,Best,Rest) :-
    (   ucb(Cand,LogNi,CandUCB),
        CandUCB > CurMax
    ->  % new best node; the previous best joins the rest
        get_max_ucb(Cands,LogNi,CandUCB,Cand,[BestSoFar|RestSoFar],Best,Rest)
    ;   get_max_ucb(Cands,LogNi,CurMax,BestSoFar,[Cand|RestSoFar],Best,Rest)
    ).
320
321
% Accessors for MCTS trees. A tree is either leaf(StateID) (not yet expanded,
% counting as 0 wins and 0 visits) or node(StateID,Wins,Visits,Children).
% helpers for the old maplist-based selection:
%project(ucb(_,Node),Node).
%create_ucb_node(LogNi,Node,ucb(UCB,Node)) :- ucb(Node,LogNi,UCB).

get_visits(node(_,_,Visits,_),Visits).
get_visits(leaf(_),0).

get_wins(node(_,Wins,_,_),Wins).
get_wins(leaf(_),0).

get_node(node(StateID,_,_,_),StateID).
get_node(leaf(StateID),StateID).

% get_child(+Tree,-Child): enumerate the children of an expanded node
get_child(node(_,_,_,Children),Child) :-
    nonvar(Children),
    member(Child,Children).
332
% ucb(+Tree,+LogNi,-Score)
% UCB score of a child, where LogNi is the log of the parent's visit count:
%  - unexpanded leaf: 1000000, unvisited node: 10000 (explored first);
%  - terminal node (no children) that scored a win on every visit, as seen
%    by the player choosing it: 10001. We assume the player will always detect
%    the winning move (limited mini-max improvement); e.g. for tic-tac-toe with
%    40 SimRuns in [- - -, 0 x -, x - -], x is then blocked;
%  - otherwise Wins/Visits + sqrt(2 * LogNi / Visits).
ucb(leaf(_),_,1000000).
ucb(node(_ID,Wins,Visits,Children),LogNi,Score) :-
    (   Visits = 0
    ->  Score = 10000
    ;   Children = [],
        Wins = Visits
    ->  Score = 10001
    ;   Exploit is Wins / Visits,
        Explore is sqrt(2.0 * LogNi / Visits),
        Score is Exploit + Explore
    ).
344
:- dynamic max_simulation_length/1. % GAME_MCTS_MAX_SIMULATION_LENGTH

% simulate_for_parent(+NodeID,-Val,-Res,-NodeKind)
% Plays a random game from NodeID (bounded by max_simulation_length/1) with
% final utility Val. Res scores Val from the point of view of the player who
% moved into NodeID: 1 = win, 0.5 = draw (Val = 0), 0 = loss.
% NodeKind is the kind reported by simulate_random/4 for NodeID.
simulate_for_parent(NodeID,Val,Res,NodeKind) :-
    max_simulation_length(MaxLen),
    simulate_random(MaxLen,NodeID,Val,NodeKind),
    (   winning_node_for_parent_id(NodeID,Val) -> Res = 1
    ;   Val = 0 -> Res = 0.5 % draw
    ;   Res = 0              % loss
    ).
354
% winning_node_for_parent_id(+NodeID,+Val): as winning_node_for_parent/2, for a state id.
winning_node_for_parent_id(NodeID,Val) :-
    visited_expression(NodeID,State),
    winning_node_for_parent(State,Val).

% winning_node_for_parent(+State,+Val)
% Succeeds if utility Val is a win for the player who moved into State,
% i.e. the opponent of the player to move in State: when max is to move the
% parent minimises (win iff Val < 0), otherwise it maximises (win iff Val > 0).
winning_node_for_parent(State,Val) :-
    get_player(State,ToMove),
    (   ToMove = max
    ->  Val < 0  % parent node is minimising
    ;   Val > 0  % parent node is maximising
    ).
363
364
365 :- use_module(library(random),[random_member/2, random_permutation/2]).
% simulate_random(+MaxSteps,+NodeID,-Val,-Kind)
% Random playout of at most MaxSteps moves starting in NodeID. Val is the
% utility of the terminal state reached, or 0 (assumed draw) if no move is
% possible or the step limit is exhausted.
% Kind refers to NodeID itself: non_terminal if a first move was made,
% terminal otherwise. Note that with MaxSteps = 0 this also reports
% non-terminal states as terminal.
% TODO: call heuristic function if simulation was stopped
simulate_random(_,NodeID,Val,terminal) :-
    terminal(NodeID,Utility,_),
    !,
    Val = Utility.
simulate_random(MaxSteps,NodeID,Val,non_terminal) :-
    MaxSteps > 0,
    StepsLeft is MaxSteps - 1,
    my_random_game_move(NodeID,Next), % use random_game_move if not provided by game
    !,
    simulate_random(StepsLeft,Next,Val,_).
simulate_random(_,_,0,terminal). % no moves possible or limit exceeded: assume a draw
375
376
377 % -------------
378
379 :- use_module(probsrc(translate),[translate_bstate_limited/3]).
380 :- use_module(dotsrc(dot_graph_generator), [gen_dot_graph/6]).
381 :- use_module(probsrc(tools_strings),[ajoin/2]).
382 % DOT rendering
383
% tcltk_gen_last_mcts_tree(+MaxDepth,+File)
% Writes the saved MCTS tree (up to MaxDepth levels) as a DOT graph to File.
% Raises an error and fails if no tree has been saved.
tcltk_gen_last_mcts_tree(MaxDepth,File) :-
    (   saved_mcts_tree(Tree,_)
    ->  gen_dot_graph(File,[rankdir/'LR',no_page_size],
                      mcts_node(Tree,MaxDepth),mcts_trans(Tree,MaxDepth),
                      dot_no_same_rank,dot_no_subgraph)
    ;   add_error(mcts_game_play,'Run MCTS Game Play first and set GAME_MCTS_CACHE_LAST_TREE to TRUE',''),
        fail
    ).

% mcts_tree_available: a tree from a previous MCTS run has been saved.
mcts_tree_available :-
    saved_mcts_tree(_,_).
393
394 %gen_dot(Tree,MaxDepth) :- gen_dot_graph('mcts_tree.dot',[rankdir/'LR',no_page_size],mcts_node(Tree,MaxDepth),mcts_trans(Tree,MaxDepth),
395 % dot_no_same_rank,dot_no_subgraph).
396
397 :- use_module(probsrc(translate),[translate_event_with_limit/3]).
% mcts_node(+Tree,+MaxDepth,-NodeID,-Sub,-Desc,-Shape,-Style,-Color)
% Node callback for gen_dot_graph/6: enumerates the nodes of Tree down to
% MaxDepth levels below the root. Terminal nodes are bold and coloured by
% utility (green > 0, red < 0, blue otherwise); the current state is black,
% all other nodes gray.
mcts_node(Tree,_,NodeID,none,Desc,Shape,Style,Color) :-
    get_node(Tree,NodeID),
    get_wins(Tree,Wins),
    get_visits(Tree,Visits),
    (   specfile:b_mode,
        visited_expression(NodeID,State),
        translate_bstate_limited(State,50,StateText)
    ->  true
    ;   xtl_mode
    ->  StateText = ''
    ;   StateText = '??'
    ),
    (   get_player_in_state_id(NodeID,Player) -> true ; Player = '??' ),
    (   terminal(NodeID,Utility,WinInfo)
    ->  ajoin(['id (terminal): ',NodeID,', w=',Wins,',n=',Visits,'\\nplayer=',Player,
               ', utility=',Utility,', ',WinInfo,'\\n',StateText],Desc),
        Style = bold,
        (   Utility > 0 -> Color = green
        ;   Utility < 0 -> Color = red
        ;   Color = blue
        )
    ;   (   current_state_id(NodeID)
        ->  Prefix = 'id (current): ', Color = black
        ;   Prefix = 'id: ', Color = gray
        ),
        ajoin([Prefix,NodeID,', w=',Wins,',n=',Visits,', player=',Player,'\\n',StateText],Desc),
        Style = rounded
    ),
    Shape = rectangle.
mcts_node(Tree,MaxDepth,NodeID,none,Desc,Shape,Style,Color) :-
    MaxDepth >= 1,
    ChildDepth is MaxDepth - 1,
    get_child(Tree,Child),
    mcts_node(Child,ChildDepth,NodeID,none,Desc,Shape,Style,Color).
% mcts_trans(+Tree,+MaxDepth,-FromID,-Label,-ToID,-Color,-Style)
% Edge callback for gen_dot_graph/6: enumerates parent-child edges of Tree
% down to MaxDepth levels. The edge to the most visited child is bold, edges
% into the current state are black, and labels show the transition's event.
mcts_trans(Tree,_,FromID,Label,ToID,Color,Style) :-
    get_node(Tree,FromID),
    (   get_best_mcts_move(Tree,_,BestID) -> true ; BestID = unknown ),
    get_child(Tree,Child),
    get_node(Child,ToID),
    (   ToID = BestID -> Style = bold ; Style = solid ),
    (   current_state_id(ToID) -> Color = black ; Color = lightgray ),
    (   transition(FromID,Event,_,ToID)
    ->  translate_event_with_limit(Event,30,Label)
    ;   Label = 'move'
    ).
mcts_trans(Tree,MaxDepth,FromID,Label,ToID,Color,Style) :-
    MaxDepth > 1,
    ChildDepth is MaxDepth - 1,
    get_child(Tree,Child),
    mcts_trans(Child,ChildDepth,FromID,Label,ToID,Color,Style).
436
437
438 % -------------
439
% pretty_print_node(+StateID): prints a label for a tree node (no newline).
pretty_print_node(StateID) :-
    format('State: ~w',[StateID]).

% print_tree(+Tree,+Indent,+MaxDepth)
% Debug output of an MCTS tree. Indent is a unary counter (0, s(0), ...).
% Expanded nodes are only printed while MaxDepth > 0; a node at limit 0
% fails, which the enclosing failure-driven loop ignores.
print_tree(leaf(StateID),Indent,_) :-
    indent(Indent),
    pretty_print_node(StateID),
    nl.
print_tree(node(StateID,Wins,Visits,Children),Indent,MaxDepth) :-
    MaxDepth > 0,
    ChildDepth is MaxDepth - 1,
    indent(Indent),
    pretty_print_node(StateID),
    nl,
    length(Children,NrChildren),
    indent(Indent),
    format(' w=~w, n=~w, childs=~w~n',[Wins,Visits,NrChildren]),
    (   member(Child,Children),
        print_tree(Child,s(Indent),ChildDepth),
        fail
    ;   true
    ).

% indent(+Level): prints one ' + ' per level
indent(0).
indent(s(Level)) :-
    print(' + '),
    indent(Level).
451
% print_tree_summary(+Tree): one-line summary of the root of an MCTS tree.
print_tree_summary(leaf(StateID)) :-
    format('Tree is leaf for ~w~n',[StateID]).
print_tree_summary(node(StateID,Wins,Visits,_Children)) :-
    format('Tree for ~w: ~w wins, ~w visits~n',[StateID,Wins,Visits]).