package org.eventb.internal.pp.core.provers.equality.unionfind;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.eventb.internal.pp.core.Level;
import org.eventb.internal.pp.core.provers.equality.unionfind.Source;

/* loaded from: input_file:org/eventb/internal/pp/core/provers/equality/unionfind/EqualitySolver.class */
/**
 * Equality solver built on a union-find forest of {@link Node}s.
 *
 * <p>Every node either is a root or points to a parent. Roots carry the
 * bookkeeping for their whole tree: fact inequalities, pending query
 * equalities/inequalities and registered instantiations. The
 * {@link SourceTable} records, for pairs of nodes, the set of fact sources that
 * links them, so that every derived result can report the facts it depends on.
 *
 * <p>Several methods return {@code null} to signal "nothing was derived"; this
 * is part of the existing contract and is preserved here.
 */
public final class EqualitySolver {
    private final SourceTable sourceTable;

    /** Every node that has been passed to one of the fact or query entry points; used by {@link #dump()}. */
    private final Set<Node> nodes = new HashSet<>();

    /**
     * Creates a solver that records its sources in the given table.
     *
     * @param sourceTable table used to store and look up the fact sources linking two nodes
     */
    public EqualitySolver(SourceTable sourceTable) {
        this.sourceTable = sourceTable;
    }

    /** Remembers {@code node} so that it shows up in {@link #dump()}. */
    private void addNode(Node node) {
        nodes.add(node);
    }

    /**
     * Adds the fact {@code left = right} and merges the two trees when they are distinct.
     *
     * @param equality the fact equality; its left node must compare strictly below its
     *                 right node (checked by assertion only)
     * @return {@code null} when both nodes already share a root or when nothing follows
     *         from the merge; a {@link FactResult} built from a set of fact sources when a
     *         recorded fact inequality now relates the merged root to itself; otherwise a
     *         {@link FactResult} holding the queries that became decided
     */
    public FactResult addFactEquality(Equality<Source.FactSource> equality) {
        Node left = equality.getLeft();
        Node right = equality.getRight();
        addNode(left);
        addNode(right);
        assert left.compareTo(right) < 0;

        sourceTable.addSource(left, right, equality.getSource());

        Set<Source.FactSource> leftPath = new HashSet<>();
        Node leftRoot = find(left, leftPath);
        Set<Source.FactSource> rightPath = new HashSet<>();
        Node rightRoot = find(right, rightPath);
        if (leftRoot == rightRoot) {
            return null;
        }

        // The two roots are linked by both paths to the roots plus the new fact itself.
        Set<Source.FactSource> rootLink = new HashSet<>();
        rootLink.addAll(leftPath);
        rootLink.addAll(rightPath);
        rootLink.add(equality.getSource());
        union(leftRoot, rightRoot, rootLink);

        Set<Source.FactSource> contradiction = checkContradiction(leftRoot);
        if (contradiction != null) {
            return new FactResult(contradiction);
        }

        List<QueryResult> solvedQueries = new ArrayList<>();
        solvedQueries.addAll(checkQueryContradiction(leftRoot, true));
        solvedQueries.addAll(checkQueryContradiction(leftRoot, false));
        if (solvedQueries.isEmpty()) {
            return null;
        }
        return new FactResult(solvedQueries, false);
    }

    /**
     * Collects the queries stored on {@code root} whose other side now resolves to
     * {@code root} itself, and drops queries whose source is no longer valid.
     *
     * @param root       root whose query list is inspected
     * @param isEquality {@code true} to inspect query equalities, {@code false} for query
     *                   inequalities; also passed on as the flag of each {@link QueryResult}
     * @return the decided queries, possibly empty
     */
    private List<QueryResult> checkQueryContradiction(Node root, boolean isEquality) {
        List<QueryResult> results = new ArrayList<>();
        List<Equality<Source.QuerySource>> stale = new ArrayList<>();
        for (RootInfo<Source.QuerySource> info
                : isEquality ? root.getRootQueryEqualities() : root.getRootQueryInequalities()) {
            Equality<Source.QuerySource> query = info.getEquality();
            if (!query.getSource().isValid()) {
                stale.add(query);
            } else if (info.updateAndGetInequalNode() == root) {
                results.add(new QueryResult(query.getSource(),
                        source(query.getLeft(), query.getRight()), isEquality));
            }
        }
        // Removal is deferred so the collection being iterated is never modified
        // mid-loop, whether or not the getter hands out a copy.
        removeStaleQueries(root, stale, isEquality);
        return results;
    }

    /**
     * Removes the given queries from one of {@code root}'s query lists.
     *
     * @param fromEqualities {@code true} to remove from the query equalities,
     *                       {@code false} to remove from the query inequalities
     */
    private void removeStaleQueries(Node root, List<Equality<Source.QuerySource>> stale,
            boolean fromEqualities) {
        for (Equality<Source.QuerySource> query : stale) {
            if (fromEqualities) {
                root.removeRootQueryEquality(query);
            } else {
                root.removeRootQueryInequality(query);
            }
        }
    }

    /**
     * Looks for a fact inequality stored on {@code root} whose other side also resolves to
     * {@code root}.
     *
     * @return the fact sources of one such contradiction, or {@code null} when there is
     *         none. When several exist, a candidate replaces the retained one whenever its
     *         level {@code isAncestorOf} the retained level.
     */
    private Set<Source.FactSource> checkContradiction(Node root) {
        Set<Source.FactSource> best = null;
        Level bestLevel = null;
        for (RootInfo<Source.FactSource> info : root.getRootFactsInequalities()) {
            Equality<Source.FactSource> inequality = info.getEquality();
            if (info.updateAndGetInequalNode() != root) {
                continue;
            }
            Set<Source.FactSource> candidate = source(inequality.getLeft(), inequality.getRight());
            candidate.add(inequality.getSource());
            Level candidateLevel = Source.getLevel(candidate);
            if (bestLevel == null || candidateLevel.isAncestorOf(bestLevel)) {
                best = candidate;
                bestLevel = candidateLevel;
            }
        }
        return best;
    }

    /**
     * Adds the fact {@code left != right}.
     *
     * @param equality the fact inequality; its left node must compare strictly below its
     *                 right node (checked by assertion only)
     * @return a {@link FactResult} built from a set of fact sources when both nodes already
     *         share a root; otherwise a {@link FactResult} holding the decided queries
     *         and/or the new instantiations, or {@code null} when there are neither
     */
    public FactResult addFactInequality(Equality<Source.FactSource> equality) {
        Node left = equality.getLeft();
        Node right = equality.getRight();
        addNode(left);
        addNode(right);
        assert left.compareTo(right) < 0;

        Set<Source.FactSource> leftPath = new HashSet<>();
        Node leftRoot = find(left, leftPath);
        Set<Source.FactSource> rightPath = new HashSet<>();
        Node rightRoot = find(right, rightPath);
        if (leftRoot == rightRoot) {
            // Both sides are already known equal: immediate contradiction.
            Set<Source.FactSource> contradiction = exclusiveOr(leftPath, rightPath);
            contradiction.add(equality.getSource());
            return new FactResult(contradiction);
        }

        // Each root remembers the inequality together with the opposite root.
        leftRoot.addRootFactInequality(new RootInfo<>(rightRoot, equality));
        rightRoot.addRootFactInequality(new RootInfo<>(leftRoot, equality));

        List<QueryResult> solvedQueries = new ArrayList<>();
        solvedQueries.addAll(checkInequalityContradictionWithQuery(
                leftRoot, rightRoot, left, right, equality.getSource(), false));
        solvedQueries.addAll(checkInequalityContradictionWithQuery(
                leftRoot, rightRoot, left, right, equality.getSource(), true));

        List<InstantiationResult> instantiations = new ArrayList<>();
        instantiations.addAll(checkInequalityInstantiation(
                leftRoot, rightRoot, left, right, equality.getSource()));
        instantiations.addAll(checkInequalityInstantiation(
                rightRoot, leftRoot, right, left, equality.getSource()));

        boolean hasQueries = !solvedQueries.isEmpty();
        boolean hasInstantiations = !instantiations.isEmpty();
        if (hasQueries && hasInstantiations) {
            return new FactResult(solvedQueries, instantiations);
        }
        if (hasQueries) {
            return new FactResult(solvedQueries, false);
        }
        if (hasInstantiations) {
            return new FactResult(instantiations);
        }
        return null;
    }

    /**
     * Derives instantiation results for the instantiations registered on {@code root}
     * after a new fact inequality between {@code member} and {@code otherMember}.
     *
     * @param root        root of {@code member}
     * @param otherRoot   root of {@code otherMember}
     * @param member      side of the new inequality that lives under {@code root}
     * @param otherMember side of the new inequality that lives under {@code otherRoot}
     * @param factSource  source of the new fact inequality
     * @return the newly produced instantiation results, possibly empty
     */
    private List<InstantiationResult> checkInequalityInstantiation(Node root, Node otherRoot,
            Node member, Node otherMember, Source.FactSource factSource) {
        List<InstantiationResult> results = new ArrayList<>();
        for (Instantiation instantiation : root.getRootInstantiations()) {
            Node instantiationRoot = find(instantiation.getNode(), new HashSet<Source.FactSource>());
            Node instantiatedNode = null;
            Node sameSideMember = null;
            Node oppositeRoot = null;
            if (instantiationRoot == root) {
                instantiatedNode = otherMember;
                sameSideMember = member;
                oppositeRoot = otherRoot;
            } else if (instantiationRoot == otherRoot) {
                instantiatedNode = member;
                sameSideMember = otherMember;
                oppositeRoot = root;
            } else {
                // The instantiation must belong to one of the two trees involved.
                assert false;
            }
            // NOTE(review): the lookup uses the opposite root while the save below uses
            // the instantiated member node; kept as in the original - confirm against
            // Instantiation's contract.
            if (!instantiation.hasInstantiation(oppositeRoot)) {
                Set<Source.FactSource> origin = source(sameSideMember, instantiation.getNode());
                origin.add(factSource);
                InstantiationResult result =
                        new InstantiationResult(instantiatedNode, instantiation.getSource(), origin);
                instantiation.saveInstantiation(result.getLevel(), instantiatedNode);
                results.add(result);
            }
        }
        return results;
    }

    /**
     * Collects the queries stored on {@code root} whose other side resolves to
     * {@code otherRoot}, i.e. queries decided by a new fact inequality between the two
     * trees, and drops queries whose source is no longer valid.
     *
     * @param root         root whose query list is inspected
     * @param otherRoot    root on the other side of the new inequality
     * @param left         left node of the new fact inequality
     * @param right        right node of the new fact inequality
     * @param factSource   source of the new fact inequality
     * @param inequalities {@code true} to inspect query inequalities, {@code false} for
     *                     query equalities; also passed on as the flag of each
     *                     {@link QueryResult}
     * @return the decided queries, possibly empty
     */
    private List<QueryResult> checkInequalityContradictionWithQuery(Node root, Node otherRoot,
            Node left, Node right, Source.FactSource factSource, boolean inequalities) {
        List<QueryResult> results = new ArrayList<>();
        List<Equality<Source.QuerySource>> stale = new ArrayList<>();
        for (RootInfo<Source.QuerySource> info
                : inequalities ? root.getRootQueryInequalities() : root.getRootQueryEqualities()) {
            Equality<Source.QuerySource> query = info.getEquality();
            if (!query.getSource().isValid()) {
                stale.add(query);
            } else if (info.updateAndGetInequalNode() == otherRoot) {
                Set<Source.FactSource> origin = getContradictionSourceInTwoTrees(
                        left, right, query.getLeft(), query.getRight());
                origin.add(factSource);
                results.add(new QueryResult(query.getSource(), origin, inequalities));
            }
        }
        // Removal is deferred so the collection being iterated is never modified
        // mid-loop, whether or not the getter hands out a copy.
        removeStaleQueries(root, stale, !inequalities);
        return results;
    }

    /**
     * Computes the fact sources linking two pairs of nodes that span the same two trees.
     *
     * <p>The nodes of the first pair are matched with the nodes of the second pair that
     * share their root; the result is the symmetric difference of the two linking sets.
     * The pairs are expected to span exactly the same two trees (checked by assertion
     * only).
     */
    private Set<Source.FactSource> getContradictionSourceInTwoTrees(Node firstLeft, Node firstRight,
            Node secondLeft, Node secondRight) {
        Set<Source.FactSource> firstLeftPath = new HashSet<>();
        Node firstLeftRoot = find(firstLeft, firstLeftPath);
        Set<Source.FactSource> firstRightPath = new HashSet<>();
        Node firstRightRoot = find(firstRight, firstRightPath);
        Set<Source.FactSource> secondLeftPath = new HashSet<>();
        Node secondLeftRoot = find(secondLeft, secondLeftPath);
        Set<Source.FactSource> secondRightPath = new HashSet<>();
        Node secondRightRoot = find(secondRight, secondRightPath);

        Set<Source.FactSource> linkInOneTree = null;
        Set<Source.FactSource> linkInOtherTree = null;
        if (firstLeftRoot == secondLeftRoot) {
            assert firstRightRoot == secondRightRoot;
            linkInOneTree = exclusiveOr(firstLeftPath, secondLeftPath);
            linkInOtherTree = exclusiveOr(firstRightPath, secondRightPath);
        } else if (firstLeftRoot == secondRightRoot) {
            assert firstRightRoot == secondLeftRoot;
            linkInOneTree = exclusiveOr(firstLeftPath, secondRightPath);
            linkInOtherTree = exclusiveOr(firstRightPath, secondLeftPath);
        } else {
            assert false;
        }
        return exclusiveOr(linkInOneTree, linkInOtherTree);
    }

    /**
     * Registers a query and answers it immediately when the current facts already decide it.
     *
     * @param equality the queried pair; its left node must compare strictly below its
     *                 right node (checked by assertion only)
     * @param z        {@code true} when the query is stored as a query equality,
     *                 {@code false} when it is stored as a query inequality
     * @return a {@link QueryResult} flagged {@code z} when both nodes share a root, a
     *         {@link QueryResult} flagged {@code !z} when a recorded fact inequality
     *         separates their roots, or {@code null} when the query stays pending (it is
     *         then stored on both roots)
     */
    public QueryResult addQuery(Equality<Source.QuerySource> equality, boolean z) {
        Node left = equality.getLeft();
        Node right = equality.getRight();
        addNode(left);
        addNode(right);
        assert left.compareTo(right) < 0 : "wrong order: " + left + "=" + right;

        Set<Source.FactSource> leftPath = new HashSet<>();
        Node leftRoot = find(left, leftPath);
        Set<Source.FactSource> rightPath = new HashSet<>();
        Node rightRoot = find(right, rightPath);
        if (leftRoot == rightRoot) {
            return new QueryResult(equality.getSource(), exclusiveOr(leftPath, rightPath), z);
        }

        QueryResult decided = checkQueryContradictionWithInequality(
                left, right, leftRoot, rightRoot, equality.getSource(), !z);
        if (decided != null) {
            return decided;
        }
        decided = checkQueryContradictionWithInequality(
                right, left, rightRoot, leftRoot, equality.getSource(), !z);
        if (decided != null) {
            return decided;
        }

        addRootInfo(equality, z, rightRoot, leftRoot);
        addRootInfo(equality, z, leftRoot, rightRoot);
        return null;
    }

    /**
     * Stores the query on {@code holder}, remembering {@code other} as the opposite root.
     *
     * @param isEquality {@code true} to store it among the query equalities,
     *                   {@code false} among the query inequalities
     */
    private void addRootInfo(Equality<Source.QuerySource> equality, boolean isEquality,
            Node other, Node holder) {
        RootInfo<Source.QuerySource> info = new RootInfo<>(other, equality);
        if (isEquality) {
            holder.addRootQueryEquality(info);
        } else {
            holder.addRootQueryInequality(info);
        }
    }

    /**
     * Looks on {@code root} for a fact inequality whose other side resolves to
     * {@code otherRoot}; the first match decides the query.
     *
     * @param left        one node of the query
     * @param right       the other node of the query
     * @param root        root of {@code left}
     * @param otherRoot   root of {@code right}
     * @param querySource source of the query
     * @param value       flag passed on to the {@link QueryResult}
     * @return the decided query, or {@code null} when no such fact inequality exists
     */
    private QueryResult checkQueryContradictionWithInequality(Node left, Node right, Node root,
            Node otherRoot, Source.QuerySource querySource, boolean value) {
        for (RootInfo<Source.FactSource> info : root.getRootFactsInequalities()) {
            if (info.updateAndGetInequalNode() == otherRoot) {
                Equality<Source.FactSource> inequality = info.getEquality();
                Set<Source.FactSource> origin = getContradictionSourceInTwoTrees(
                        left, right, inequality.getLeft(), inequality.getRight());
                origin.add(info.getEquality().getSource());
                return new QueryResult(querySource, origin, value);
            }
        }
        return null;
    }

    /**
     * Registers an instantiation on the root of its node and derives one result per
     * opposite root that is related to it by a fact inequality and not yet instantiated.
     *
     * <p>When several fact inequalities lead to the same opposite root, a candidate
     * replaces the retained result whenever its level {@code isAncestorOf} the retained
     * level.
     *
     * @param instantiation the instantiation to register
     * @return the derived results, or {@code null} (not an empty list) when there are none
     */
    public List<InstantiationResult> addInstantiation(Instantiation instantiation) {
        Node node = instantiation.getNode();
        Node root = find(node, new HashSet<Source.FactSource>());

        Map<Node, InstantiationResult> resultsByOppositeRoot = new HashMap<>();
        for (RootInfo<Source.FactSource> info : root.getRootFactsInequalities()) {
            if (instantiation.hasInstantiation(info.updateAndGetInequalNode())) {
                continue;
            }
            Equality<Source.FactSource> inequality = info.getEquality();
            Node leftRoot = find(inequality.getLeft(), new HashSet<Source.FactSource>());
            Node rightRoot = find(inequality.getRight(), new HashSet<Source.FactSource>());

            Node instantiatedNode = null;
            Node sameSideMember = null;
            Node oppositeRoot = null;
            if (root == rightRoot) {
                sameSideMember = inequality.getRight();
                instantiatedNode = inequality.getLeft();
                oppositeRoot = leftRoot;
            } else if (root == leftRoot) {
                sameSideMember = inequality.getLeft();
                instantiatedNode = inequality.getRight();
                oppositeRoot = rightRoot;
            } else {
                // An inequality stored on this root must have one side under it.
                assert false;
            }

            Set<Source.FactSource> origin = source(sameSideMember, node);
            origin.add(inequality.getSource());
            InstantiationResult candidate =
                    new InstantiationResult(instantiatedNode, instantiation.getSource(), origin);
            if (resultsByOppositeRoot.containsKey(oppositeRoot)) {
                InstantiationResult retained = resultsByOppositeRoot.get(oppositeRoot);
                if (candidate.getLevel().isAncestorOf(retained.getLevel())) {
                    resultsByOppositeRoot.put(oppositeRoot, candidate);
                }
            } else {
                resultsByOppositeRoot.put(oppositeRoot, candidate);
            }
        }

        for (Map.Entry<Node, InstantiationResult> entry : resultsByOppositeRoot.entrySet()) {
            instantiation.saveInstantiation(entry.getValue().getLevel(), entry.getKey());
        }
        root.addRootInstantiation(instantiation);

        if (resultsByOppositeRoot.isEmpty()) {
            return null;
        }
        return new ArrayList<>(resultsByOppositeRoot.values());
    }

    /**
     * Attaches {@code absorbed} under {@code root}, records the sources linking them and
     * moves the fact inequalities and both query lists of {@code absorbed} to {@code root}.
     */
    private void union(Node root, Node absorbed, Set<Source.FactSource> linkSource) {
        absorbed.setParent(root);
        sourceTable.addSource(root, absorbed, linkSource);
        root.addRootFactsInequalities(absorbed.getRootFactsInequalities());
        root.addRootQueryEqualities(absorbed.getRootQueryEqualities());
        root.addRootQueryInequalities(absorbed.getRootQueryInequalities());
        // NOTE(review): root instantiations are not copied here; whether that is intended
        // depends on Node.deleteRootInfos(), which is not visible from this file.
        absorbed.deleteRootInfos();
    }

    /**
     * Returns the root of {@code node} and adds to {@code set} the fact sources linking
     * {@code node} to that root.
     *
     * <p>A node whose parent is not a root is re-parented directly onto the root, and the
     * source table is updated for the shortened link. For a root, {@code set} is left
     * untouched.
     */
    private Node find(Node node, Set<Source.FactSource> set) {
        if (node.isRoot()) {
            return node;
        }
        Set<Source.FactSource> parentPath = new HashSet<>();
        Node root = find(node.getParent(), parentPath);
        // node->root is the parent's path to the root combined with the node->parent link.
        sourceTable.addSource(node, root,
                exclusiveOr(parentPath, sourceTable.getSource(node, node.getParent())));
        set.addAll(sourceTable.getSource(node, root));
        if (!node.getParent().isRoot()) {
            node.setParent(root);
            sourceTable.addSource(node, root, set);
        }
        return root;
    }

    /**
     * Returns the symmetric difference of the source sets linking each of the two nodes
     * to its root.
     */
    private Set<Source.FactSource> source(Node first, Node second) {
        Set<Source.FactSource> firstPath = new HashSet<>();
        find(first, firstPath);
        Set<Source.FactSource> secondPath = new HashSet<>();
        find(second, secondPath);
        return exclusiveOr(firstPath, secondPath);
    }

    /**
     * Returns a new set holding the elements that belong to exactly one of the two sets.
     * Neither argument is modified.
     */
    private <T extends Source> Set<T> exclusiveOr(Set<T> set, Set<T> set2) {
        Set<T> difference = new HashSet<>();
        for (T element : set) {
            if (!set2.contains(element)) {
                difference.add(element);
            }
        }
        for (T element : set2) {
            if (!set.contains(element)) {
                difference.add(element);
            }
        }
        return difference;
    }

    /**
     * Describes the current state as a set of strings: one {@code child->parent} entry per
     * non-root node and one entry per fact inequality, query equality and query inequality
     * stored on a node.
     *
     * @return the descriptions, in insertion order
     */
    public Set<String> dump() {
        Set<String> lines = new LinkedHashSet<>();
        for (Node node : nodes) {
            if (!node.isRoot()) {
                lines.add(node.toString() + "->" + node.getParent().toString());
            }
            for (RootInfo<Source.FactSource> info : node.getRootFactsInequalities()) {
                lines.add(node.toString() + "[F, ≠" + info.updateAndGetInequalNode() + "]");
            }
            for (RootInfo<Source.QuerySource> info : node.getRootQueryEqualities()) {
                lines.add(node.toString() + "[Q, =" + info.updateAndGetInequalNode() + "]");
            }
            for (RootInfo<Source.QuerySource> info : node.getRootQueryInequalities()) {
                lines.add(node.toString() + "[Q, ≠" + info.updateAndGetInequalNode() + "]");
            }
        }
        return lines;
    }

    /** Returns the string form of {@link #dump()}. */
    @Override
    public String toString() {
        return dump().toString();
    }

    /** Returns the source table this solver was created with. */
    public SourceTable getSourceTable() {
        return sourceTable;
    }
}
