W = [[-0.197, 0.093], [-0.140, 0.161], [ 0.116, -0.133], [ 0.078, -0.001], [-0.127, 0.065]]
States  = [0, 1, 2, 3]
Actions = [1, 1, 1, 1]
Rewards = [-0.1, -0.1, -0.1, 1.0]
Gamma   = 0.95
G3 = 1.0
G2 = -0.1 + 0.95 * 1.0    = -0.1 + 0.95     = 0.85
G1 = -0.1 + 0.95 * 0.85   = -0.1 + 0.8075   = 0.7075
G0 = -0.1 + 0.95 * 0.7075 = -0.1 + 0.672125 = 0.572125
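The backward recursion above (each return is the reward plus gamma times the next return) can be sketched in a few lines of Python; variable names are mine:

```python
# Discounted returns, computed backward through the episode.
rewards = [-0.1, -0.1, -0.1, 1.0]
gamma = 0.95

returns = []
G = 0.0
for r in reversed(rewards):
    G = r + gamma * G       # G_t = r_t + gamma * G_{t+1}
    returns.insert(0, G)    # prepend so returns line up with time steps
```

Iterating backward means each return is computed in O(1) instead of re-summing the tail of the episode.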
Mean = (0.572125 + 0.7075 + 0.85 + 1.0) / 4
= 3.129625 / 4
= 0.78240625
Variance =
((0.572125 - 0.78240625)^2 +
(0.7075 - 0.78240625)^2 +
(0.85 - 0.78240625)^2 +
(1.0 - 0.78240625)^2) / 4
= (0.0442 + 0.0056 + 0.0046 + 0.0473) / 4
= 0.1017 / 4
= 0.0254
Std = sqrt(0.0254) = 0.159
Normalized:
G0 = (0.572125 - 0.78240625) / 0.159 = -1.32
G1 = (0.7075 - 0.78240625) / 0.159 = -0.47
G2 = (0.85 - 0.78240625) / 0.159 = 0.42
G3 = (1.0 - 0.78240625) / 0.159 = 1.37
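The mean/variance normalization above, as a minimal Python sketch (the worked numbers use a std rounded to 0.159, so the last digits differ slightly):

```python
import math

# Normalize returns: subtract the mean, divide by the standard deviation.
returns = [0.572125, 0.7075, 0.85, 1.0]

mean = sum(returns) / len(returns)
var = sum((g - mean) ** 2 for g in returns) / len(returns)  # population variance
std = math.sqrt(var)
normalized = [(g - mean) / std for g in returns]
```

Normalized returns act as a baseline: steps with above-average returns get positive weight, below-average ones get negative weight, which reduces gradient variance.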
x = [1, 0, 0, 0, 0]
logits[0] = 1 * (-0.197) + 0 + 0 + 0 + 0 = -0.197
logits[1] = 1 * (0.093)  + 0 + 0 + 0 + 0 = 0.093
max = 0.093
shifted = [-0.197 - 0.093, 0.093 - 0.093]
= [-0.29, 0]
exp = [e^-0.29, e^0]
= [0.748, 1]
sum = 1.748
probs = [0.748/1.748, 1/1.748]
= [0.428, 0.572]
Chosen action = 1
d_logits = one-hot(action) - probs:
d_logits[0] = 0 - 0.428 = -0.428
d_logits[1] = 1 - 0.572 = 0.428
dW row 0 = [-0.428 * -1.32, 0.428 * -1.32] = [0.565, -0.565]
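The per-step computation above can be sketched as follows, assuming a one-hot state so the state's row of W is the logit vector (names are illustrative):

```python
import math

logits = [-0.197, 0.093]   # W row for state 0
G = -1.32                  # normalized return for this step
action = 1                 # action actually taken

# Numerically stable softmax: subtract the max before exponentiating.
m = max(logits)
exps = [math.exp(z - m) for z in logits]
total = sum(exps)
probs = [e / total for e in exps]

# Gradient of log pi(action|state) w.r.t. the logits: one-hot(action) - probs.
d_logits = [(1.0 if a == action else 0.0) - p for a, p in enumerate(probs)]

# Scale by the (normalized) return to get this state's row gradient.
dW_row = [d * G for d in d_logits]
```

Subtracting the max before exponentiating does not change the probabilities but prevents overflow for large logits.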
x = [0, 1, 0, 0, 0]
logits = [-0.14, 0.161]
shifted = [-0.301, 0]
exp = [0.740, 1]
sum = 1.740
probs = [0.425, 0.575]
d_logits = [-0.425, 0.425]
G1 = -0.47
dW row 1 = [-0.425 * -0.47, 0.425 * -0.47] = [0.199, -0.199]
x = [0, 0, 1, 0, 0]
logits = [0.116, -0.133]
shifted = [0, -0.249]
exp = [1, 0.779]
sum = 1.779
probs = [0.562, 0.438]
d_logits = [-0.562, 0.562]
G2 = 0.42
dW row 2 = [-0.562 * 0.42, 0.562 * 0.42] = [-0.236, 0.236]
x = [0, 0, 0, 1, 0]
logits = [0.078, -0.001]
shifted = [0, -0.079]
exp = [1, 0.924]
sum = 1.924
probs = [0.520, 0.480]
d_logits = [-0.520, 0.520]
G3 = 1.37
dW row 3 = [-0.520 * 1.37, 0.520 * 1.37] = [-0.712, 0.712]
dW = [[ 0.565, -0.565],
      [ 0.199, -0.199],
      [-0.236,  0.236],
      [-0.712,  0.712],
      [ 0.000,  0.000]]   (state 4 was never visited, so its row gets zero gradient)
W_new = W + 0.01 * dW
Row 0: [-0.197 + 0.00565, 0.093 - 0.00565] = [-0.19135, 0.08735]
Row 1: [-0.14 + 0.00199, 0.161 - 0.00199]  = [-0.13801, 0.15901]
Row 2: [0.116 - 0.00236, -0.133 + 0.00236] = [0.11364, -0.13064]
Row 3: [0.078 - 0.00712, -0.001 + 0.00712] = [0.07088, 0.00612]
Row 4: unchanged
W_new = [[-0.191, 0.087],
         [-0.138, 0.159],
         [ 0.114, -0.131],
         [ 0.071,  0.006],
         [-0.127,  0.065]]
Because each state is one-hot encoded, each visited state updates only its own row of W.
If the normalized return G is positive, the update increases the probability of the chosen action;
if G is negative, it decreases it.
This is how the agent "learns" from experience.
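Putting all the steps together, here is one possible end-to-end sketch of the update (helper names like `reinforce_update` are mine; a small epsilon is added to the std to guard against division by zero):

```python
import math

def softmax(zs):
    """Numerically stable softmax over a list of logits."""
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    s = sum(exps)
    return [e / s for e in exps]

def reinforce_update(W, states, actions, rewards, gamma=0.95, lr=0.01):
    """One REINFORCE update with normalized returns.
    States are one-hot, so state s indexes row W[s] directly."""
    # 1. Discounted returns, computed backward.
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    # 2. Normalize (mean 0, std 1); eps avoids division by zero.
    mean = sum(returns) / len(returns)
    std = math.sqrt(sum((g - mean) ** 2 for g in returns) / len(returns)) + 1e-8
    returns = [(g - mean) / std for g in returns]
    # 3. Per-step gradient ascent on each visited state's row.
    for s, a, G in zip(states, actions, returns):
        probs = softmax(W[s])
        for j in range(len(W[s])):
            d = (1.0 if j == a else 0.0) - probs[j]   # d log pi / d logit_j
            W[s][j] += lr * d * G
    return W

W = [[-0.197, 0.093], [-0.140, 0.161], [0.116, -0.133],
     [0.078, -0.001], [-0.127, 0.065]]
W = reinforce_update(W, [0, 1, 2, 3], [1, 1, 1, 1], [-0.1, -0.1, -0.1, 1.0])
```

Running this reproduces the updated W above up to the rounding used in the hand calculation; row 4 is untouched because state 4 never appears in the episode.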