package aima.logic.fol;

import aima.logic.fol.parsing.ast.FOLNode;
import aima.logic.fol.parsing.ast.Function;
import aima.logic.fol.parsing.ast.Term;
import aima.logic.fol.parsing.ast.Variable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:aima/logic/fol/Unifier.class */
/**
 * Syntactic unifier for first-order-logic parse-tree nodes.
 *
 * <p>The method layout mirrors the textbook UNIFY / UNIFY-VAR pseudocode (UNIFY, UNIFY-VAR, OP,
 * ARGS, COMPOUND?). A substitution is represented as a {@code Map<Variable, Term>}, and a
 * {@code null} map means "unification has failed". Maps passed in are updated in place.
 *
 * <p>NOTE(review): the two visitors below are shared static instances; whether that is safe for
 * concurrent use depends on SubstVisitor / VariableCollector, which are not visible here.
 */
public class Unifier {
    // Used to re-apply the current bindings to every bound term (see cascadeSubstitution).
    private static final SubstVisitor _substVisitor = new SubstVisitor();
    // Used by the occur check to obtain the variables appearing inside a Function term.
    private static final VariableCollector _variableCollector = new VariableCollector();

    /**
     * Unifies two nodes starting from an empty substitution.
     *
     * @param fOLNode  first node
     * @param fOLNode2 second node
     * @return the resulting substitution, or {@code null} if the nodes cannot be unified
     */
    public Map<Variable, Term> unify(FOLNode fOLNode, FOLNode fOLNode2) {
        return unify(fOLNode, fOLNode2, new LinkedHashMap<Variable, Term>());
    }

    /**
     * Unifies two nodes, extending an existing substitution.
     *
     * <p>{@code map} is modified in place and the same instance is returned on success. On failure
     * {@code null} is returned, and {@code map} may already hold bindings added before the
     * mismatch was detected.
     *
     * @param fOLNode  first node
     * @param fOLNode2 second node
     * @param map      substitution built so far, or {@code null} to propagate an earlier failure
     * @return the extended substitution, or {@code null} if the nodes cannot be unified
     */
    public Map<Variable, Term> unify(FOLNode fOLNode, FOLNode fOLNode2, Map<Variable, Term> map) {
        if (map == null) {
            return null;
        }
        if (fOLNode.equals(fOLNode2)) {
            return map;
        }
        if (fOLNode instanceof Variable) {
            return unifyVar((Variable) fOLNode, fOLNode2, map);
        }
        if (fOLNode2 instanceof Variable) {
            return unifyVar((Variable) fOLNode2, fOLNode, map);
        }
        if (isCompound(fOLNode) && isCompound(fOLNode2)) {
            // The operators must match (otherwise unifyOps yields null, which the list overload
            // propagates); then the argument lists are unified pairwise.
            return unify(args(fOLNode), args(fOLNode2), unifyOps(op(fOLNode), op(fOLNode2), map));
        }
        return null;
    }

    /**
     * Unifies two node lists element by element, threading the substitution through each pair.
     *
     * @param list  first list of nodes
     * @param list2 second list of nodes
     * @param map   substitution built so far, or {@code null} to propagate an earlier failure
     * @return the extended substitution, or {@code null} if {@code map} is {@code null}, the lists
     *     differ in length, or some pair of elements does not unify
     */
    public Map<Variable, Term> unify(List<? extends FOLNode> list, List<? extends FOLNode> list2, Map<Variable, Term> map) {
        if (map == null || list.size() != list2.size()) {
            return null;
        }
        // Iterate instead of recursing on subList views so stack depth does not grow with arity.
        Map<Variable, Term> theta = map;
        Iterator<? extends FOLNode> left = list.iterator();
        Iterator<? extends FOLNode> right = list2.iterator();
        while (theta != null && left.hasNext()) {
            theta = unify(left.next(), right.next(), theta);
        }
        return theta;
    }

    /**
     * Occur check: tests whether binding {@code variable} to {@code fOLNode} would make the
     * substitution cyclic.
     *
     * <p>Only {@link Function} nodes are examined; for any other node this returns {@code false}.
     * For a function, this returns {@code true} if {@code variable} occurs among its variables or,
     * transitively, among the variables of the terms those variables are bound to in {@code map}.
     *
     * @param map      current substitution
     * @param variable variable about to be bound
     * @param fOLNode  value it would be bound to
     * @return {@code true} if the binding must be rejected
     */
    protected boolean occurCheck(Map<Variable, Term> map, Variable variable, FOLNode fOLNode) {
        if (!(fOLNode instanceof Function)) {
            return false;
        }
        Set<Variable> functionVars = _variableCollector.collectAllVariables((Function) fOLNode);
        if (functionVars.contains(variable)) {
            return true;
        }
        return cascadeOccurCheck(map, variable, functionVars, new HashSet<Variable>(functionVars));
    }

    /**
     * UNIFY-VAR: unifies a variable with a node under {@code map}.
     *
     * @return the extended substitution, or {@code null} if {@code fOLNode} is not a
     *     {@link Term}, the existing bindings conflict, or the occur check fails
     */
    private Map<Variable, Term> unifyVar(Variable variable, FOLNode fOLNode, Map<Variable, Term> map) {
        // Only terms may be bound to a variable.
        if (!(fOLNode instanceof Term)) {
            return null;
        }
        // Already-bound variable: unify its current value instead.
        if (map.containsKey(variable)) {
            return unify(map.get(variable), fOLNode, map);
        }
        // The other side is a bound variable: unify against its value.
        if (map.containsKey(fOLNode)) {
            return unify(variable, map.get(fOLNode), map);
        }
        if (occurCheck(map, variable, fOLNode)) {
            return null;
        }
        cascadeSubstitution(map, variable, (Term) fOLNode);
        return map;
    }

    /** Returns {@code map} if it is non-null and the two operator symbols are equal, else {@code null}. */
    private Map<Variable, Term> unifyOps(String str, String str2, Map<Variable, Term> map) {
        if (map != null && str.equals(str2)) {
            return map;
        }
        return null;
    }

    /** ARGS: the argument list of a compound node. */
    private List<? extends FOLNode> args(FOLNode fOLNode) {
        return fOLNode.getArgs();
    }

    /** OP: the symbolic name (operator / functor) of a node. */
    private String op(FOLNode fOLNode) {
        return fOLNode.getSymbolicName();
    }

    /** COMPOUND?: delegates to {@link FOLNode#isCompound()}. */
    private boolean isCompound(FOLNode fOLNode) {
        return fOLNode.isCompound();
    }

    /**
     * Transitive part of the occur check, done as a breadth-first worklist over variables.
     *
     * @param map      current substitution
     * @param variable variable being bound
     * @param toCheck  variables whose bindings are inspected in the first round (not modified)
     * @param seen     variables already scheduled; updated in place to avoid revisiting
     * @return {@code true} if {@code variable} is reachable through the bindings
     */
    private boolean cascadeOccurCheck(Map<Variable, Term> map, Variable variable, Set<Variable> toCheck, Set<Variable> seen) {
        Set<Variable> frontier = toCheck;
        while (true) {
            Set<Variable> next = new HashSet<Variable>();
            for (Variable candidate : frontier) {
                Term bound = map.get(candidate);
                if (bound == null) {
                    continue;
                }
                if (bound.equals(variable)) {
                    return true;
                }
                if (bound instanceof Function) {
                    Set<Variable> boundVars = _variableCollector.collectAllVariables((Function) bound);
                    if (boundVars.contains(variable)) {
                        return true;
                    }
                    for (Variable inner : boundVars) {
                        if (!seen.contains(inner)) {
                            next.add(inner);
                        }
                    }
                }
            }
            if (next.isEmpty()) {
                return false;
            }
            seen.addAll(next);
            frontier = next;
        }
    }

    /**
     * Adds the binding {@code variable -> term} to {@code map}, then re-applies the whole
     * substitution to every bound value, including the new one.
     *
     * <p>Assumes SubstVisitor.subst replaces bound variables with their values, so earlier
     * bindings that mention {@code variable} are rewritten in terms of {@code term}.
     */
    private void cascadeSubstitution(Map<Variable, Term> map, Variable variable, Term term) {
        map.put(variable, term);
        // Entry.setValue is the supported way to change values while iterating a map.
        for (Map.Entry<Variable, Term> entry : map.entrySet()) {
            entry.setValue(_substVisitor.subst(map, entry.getValue()));
        }
    }
}
