#!/usr/bin/env python3

#
# Tabular Q-learning for a walker on a periodic 1D lattice
# (not blind: the agent observes its current position)
#
# based on
# https://www.geeksforgeeks.org/q-learning-in-python/

import numpy as np

# --- environment -------------------------------------------------------------
# The walker occupies one of the sites 0 .. n_states-1; taking action a moves
# it a+1 sites forward (positions wrap around periodically in the training loop).
n_states = 6     # number of sites in the 1D world
n_actions = 2    # number of possible moves; action a steps a+1 sites
goal_state = 5   # terminal site: the game ends with a win
loss_state = 3   # terminal site: the game ends with a loss

# Q(S, A) estimates: one row per state, one column per action, all zero at start.
Q_table = np.zeros((n_states, n_actions))

# --- learning hyper-parameters -----------------------------------------------
learning_rate = 0.8       # step size of each Q update
discount_factor = 0.95    # weight given to future rewards
exploration_prob = 0.2    # epsilon of the epsilon-greedy policy
nEpochs = 5000            # number of training games

# --- allowed starting positions ----------------------------------------------
# Every site except the two terminal (loss/win) sites.
start_pos = {site for site in range(n_states) if site not in (loss_state, goal_state)}

#
# Q-learning from random starting states
#
# Each epoch plays one game: start on a random non-terminal site and keep
# moving until the walker lands on the goal site (reward +1) or the loss
# site (reward -1); every other move gives reward 0.

# The candidate start states never change, so convert the set once
# instead of on every epoch (same order, hence same random draws).
start_states = list(start_pos)

for epoch in range(nEpochs):
    current_state = np.random.choice(start_states)
    game_length = 0                                  # currently not used
    while current_state != goal_state and current_state != loss_state:
        game_length += 1

        # epsilon-greedy action selection
        if np.random.rand() < exploration_prob:
            action = np.random.randint(0, n_actions)     # explore
        else:
            # exploit; np.argmax returns the lowest index on ties
            action = np.argmax(Q_table[current_state])

        # action a advances the walker a+1 sites; periodic boundaries
        next_state = (current_state + 1 + action) % n_states

        # reward is +1 for the goal site, -1 for the loss site, 0 otherwise
        reward = 0.0
        if next_state == goal_state:
            reward = 1.0
        if next_state == loss_state:
            reward = -1.0

        # A terminal transition has no future value, so its target is the
        # reward alone. Stating this explicitly keeps the update correct even
        # if the Q-table is ever initialised with non-zero terminal rows.
        terminal = next_state == goal_state or next_state == loss_state
        future = 0.0 if terminal else np.max(Q_table[next_state])

        # Q(S,A) <- Q(S,A) + lr * (R + gamma * future - Q(S,A))
        Q_table[current_state, action] += learning_rate * (
            reward + discount_factor * future - Q_table[current_state, action]
        )

        current_state = next_state                       # move to next state

#
# printing: one row per state with its Q value for every action,
# terminal states tagged as "loss" / "win"
#
print("expected discounted rewards Q(S,A)")
print("state,  Q-table")
for iS in range(n_states):
    # Q values of all actions for this state, each in a 7.3f column
    cells = "".join(f'{q:7.3f} ' for q in Q_table[iS])

    # mark the terminal states
    tag = ""
    if iS == loss_state:
        tag += " loss"
    if iS == goal_state:
        tag += " win"

    print(f'{iS:5d} {cells}{tag}')
