package aima.learning.reinforcement;

import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.util.FrequencyCounter;
import aima.util.Pair;
import java.util.Hashtable;

/* loaded from: input_file:aima/learning/reinforcement/QLearningAgent.class */
/**
 * Agent that learns action values for an {@link MDP} by feeding every observed transition
 * (previous state, previous action, current state, current reward) into a {@link QTable}.
 *
 * <p>The agent keeps the state/action/reward of the previous step so that each new perception
 * completes one transition. Reaching a terminal state clears that memory, so the next perception
 * of the initial state starts a new trial.
 *
 * @param <STATE_TYPE> type used to represent MDP states
 * @param <ACTION_TYPE> type used to represent MDP actions
 */
public class QLearningAgent<STATE_TYPE, ACTION_TYPE> extends MDPAgent<STATE_TYPE, ACTION_TYPE> {
    /**
     * Value handed to {@code QTable.upDateQ} as its last argument on every update.
     * NOTE(review): presumably the discount factor; confirm against {@code QTable.upDateQ}.
     */
    private static final double UPDATE_FACTOR = 0.8d;

    /** Number of Q updates after which {@link #learningFunction(Double)} resets the counter. */
    private static final int ACTION_COUNTER_LIMIT = 3;

    /** Legacy table exposed through {@link #getQ()}; this class never writes to it. */
    private Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> Q;
    /** Records every (state, action) pair for which a transition has been observed. */
    private FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>> stateActionCount;
    /** Reward seen on the previous step; {@code null} between trials. */
    private Double previousReward;
    /** Table that holds the learned values and performs the actual update. */
    private QTable<STATE_TYPE, ACTION_TYPE> qTable;
    /** Counts Q updates; reset by {@link #learningFunction(Double)}. */
    private int actionCounter;

    /**
     * Creates an agent for the given decision process with empty tables and counters.
     *
     * @param mdp the decision process supplying the action list, initial state and terminal test
     */
    public QLearningAgent(MDP<STATE_TYPE, ACTION_TYPE> mdp) {
        super(mdp);
        this.Q = new Hashtable<>();
        this.qTable = new QTable<>(mdp.getAllActions());
        this.stateActionCount = new FrequencyCounter<>();
        this.actionCounter = 0;
    }

    /**
     * Processes one perception and chooses the next action.
     *
     * <ul>
     *   <li>At the start of a trial no transition exists yet, so the initial action is chosen and
     *       remembered without touching the Q table.</li>
     *   <li>Otherwise the transition that led here is counted and pushed into the Q table.</li>
     *   <li>In a terminal state the remembered step is cleared afterwards and no action is
     *       returned.</li>
     * </ul>
     *
     * @param perception the state and reward currently perceived
     * @return the action to execute next, or {@code null} when the perceived state is terminal
     */
    @Override // aima.learning.reinforcement.MDPAgent
    public ACTION_TYPE decideAction(MDPPerception<STATE_TYPE> perception) {
        this.currentState = perception.getState();
        this.currentReward = Double.valueOf(perception.getReward());

        if (startingTrial()) {
            updateLearnerState(selectRandomAction());
            return this.previousAction;
        }

        // Evaluated before the update to keep the original call order.
        boolean terminal = this.mdp.isTerminalState(this.currentState);

        incrementStateActionCount(this.previousState, this.previousAction);
        ACTION_TYPE nextAction = updateQ(UPDATE_FACTOR);

        if (terminal) {
            // End of trial: forget the last step so startingTrial() can fire again.
            this.previousAction = null;
            this.previousState = null;
            this.previousReward = null;
            return null;
        }

        updateLearnerState(nextAction);
        return this.previousAction;
    }

    /**
     * Remembers the chosen action together with the current state and reward, so that the next
     * perception can complete the transition.
     *
     * @param action the action about to be executed
     */
    private void updateLearnerState(ACTION_TYPE action) {
        this.previousAction = action;
        this.previousState = this.currentState;
        this.previousReward = this.currentReward;
    }

    /**
     * Counts one update and delegates the remembered transition to the Q table.
     *
     * @param d value forwarded as the last argument of {@code QTable.upDateQ}
     * @return whatever action {@code QTable.upDateQ} returns
     */
    private ACTION_TYPE updateQ(double d) {
        this.actionCounter++;
        double probability = calculateProbabilityOf(this.previousState, this.previousAction);
        return this.qTable.upDateQ(
                this.previousState,
                this.previousAction,
                this.currentState,
                probability,
                this.currentReward.doubleValue(),
                d);
    }

    /**
     * Computes, over the entries returned by {@code stateActionCount.getStates()}, the share of
     * entries for {@code state} whose action equals {@code action}.
     *
     * @param state the state to look up
     * @param action the action to look up
     * @return matching entries divided by all entries for {@code state}; {@code 0.0} when there is
     *     no entry for {@code state} (avoids a NaN from {@code 0.0 / 0.0})
     */
    private double calculateProbabilityOf(STATE_TYPE state, ACTION_TYPE action) {
        double entriesForState = 0.0d;
        double entriesForStateAndAction = 0.0d;
        for (Pair<STATE_TYPE, ACTION_TYPE> entry : this.stateActionCount.getStates()) {
            if (!entry.getFirst().equals(state)) {
                continue;
            }
            entriesForState += 1.0d;
            if (entry.getSecond().equals(action)) {
                entriesForStateAndAction += 1.0d;
            }
        }
        if (entriesForState == 0.0d) {
            return 0.0d;
        }
        return entriesForStateAndAction / entriesForState;
    }

    /**
     * Picks the action whose {@link #learningFunction(Double)} value in the current state is
     * strictly greatest; the first such action wins on ties.
     *
     * @return the best action, or {@code null} when no action scores above negative infinity
     */
    private ACTION_TYPE actionMaximizingLearningFunction() {
        ACTION_TYPE bestAction = null;
        Double bestValue = Double.valueOf(Double.NEGATIVE_INFINITY);
        for (ACTION_TYPE candidate : this.mdp.getAllActions()) {
            Double value = learningFunction(this.qTable.getQValue(this.currentState, candidate));
            if (value.doubleValue() > bestValue.doubleValue()) {
                bestValue = value;
                bestAction = candidate;
            }
        }
        return bestAction;
    }

    /**
     * Returns the given value unchanged while the update counter is at most
     * {@link #ACTION_COUNTER_LIMIT}; once the counter exceeds the limit, resets it to zero and
     * returns {@code 1.0} instead.
     *
     * @param d the value to pass through
     * @return {@code d}, or {@code 1.0} right after the counter has been reset
     */
    private Double learningFunction(Double d) {
        if (this.actionCounter <= ACTION_COUNTER_LIMIT) {
            return d;
        }
        this.actionCounter = 0;
        return Double.valueOf(1.0d);
    }

    /**
     * Selects the action used to open a trial. Despite the name, this deterministically returns
     * the first action listed by the MDP.
     *
     * @return the first element of {@code mdp.getAllActions()}
     */
    private ACTION_TYPE selectRandomAction() {
        return this.mdp.getAllActions().get(0);
    }

    /**
     * Tells whether a new trial is beginning: nothing is remembered from a previous step and the
     * current state is the MDP's initial state.
     *
     * @return {@code true} when no previous action, state or reward is stored and the current
     *     state equals the initial state
     */
    private boolean startingTrial() {
        boolean nothingRemembered = this.previousAction == null
                && this.previousState == null
                && this.previousReward == null;
        return nothingRemembered && this.currentState.equals(this.mdp.getInitialState());
    }

    /**
     * Records one more occurrence of the given (state, action) pair.
     *
     * @param state the state the action was taken in
     * @param action the action that was taken
     */
    private void incrementStateActionCount(STATE_TYPE state, ACTION_TYPE action) {
        this.stateActionCount.incrementFor(new Pair<>(state, action));
    }

    /**
     * Returns the legacy hashtable created in the constructor.
     *
     * @return the internal table (not a copy); this class never adds entries to it
     */
    public Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> getQ() {
        return this.Q;
    }

    /**
     * Returns the table that receives every Q update.
     *
     * @return the internal {@link QTable} instance (not a copy)
     */
    public QTable<STATE_TYPE, ACTION_TYPE> getQTable() {
        return this.qTable;
    }
}
