Motivation:

Let’s suppose we want to approximate a controller for a system that has a discrete action space. This controller may be formally defined as a conditional distribution $\pi_\theta(a \mid s)$, parametrised by a function approximator, which computes the probability of a particular action $a$ given the current state $s$. So given a state $s$, the output of $\pi_\theta$ is a vector of probabilities:

$$P = \pi_\theta(s) = (p_1, \dots, p_n), \qquad p_i = \pi_\theta(a_i \mid s)$$

Now, assuming that we use a multinomial distribution parametrised by $P$ to sample from the $n$ possible actions, we might want to know the log-probability of the action taken in order to update the controller using a Policy-Gradient method. The challenge, however, is that if we want to use gradient-based methods to update $\pi_\theta$, we can’t backpropagate through stochastic nodes.
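
For concreteness, here is a minimal numpy sketch of this sampling step; the probability vector P below is an arbitrary, hypothetical controller output chosen purely for illustration:

import numpy as np

## hypothetical controller output for the current state (arbitrary values)
P = np.array([0.1, 0.6, 0.2, 0.1])

## sample an action index from the multinomial (categorical) distribution
action = np.random.choice(len(P), p=P)

## log-probability of the sampled action: the quantity a Policy-Gradient
## update needs, but the sampling step itself blocks backpropagation
log_prob = np.log(P[action])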

The max operator:

If we’re in a quasi-deterministic setting, the optimal policy would behave in a deterministic manner and choose the most probable action at each instant, so the log-probability of the chosen action may be approximated by:

$$\log \pi_\theta(a \mid s) \approx \log\left(\max_i \, p_i\right)$$

The problem with the $\max$ operator is that it isn’t differentiable, and so we wouldn’t be able to use backpropagation to update the controller in question. But perhaps we can find good differentiable approximations to the $\max$ operator.

Approximations to the max operator:

Considering that all the elements of $P$ are non-negative, we have:

$$\max_i \, p_i = \|P\|_\infty$$

While the infinity norm isn’t a differentiable operator, we aren’t far from what we need. If we consider that:

$$\|P\|_\infty = \lim_{N \to \infty} \left(\sum_i p_i^N\right)^{\frac{1}{N}}$$

we may deduce that for sufficiently large $N$:

$$\max_i \, p_i \approx \left(\sum_i p_i^N\right)^{\frac{1}{N}}$$

Indeed, if we run the following numpy script, we find that the quality of the approximation reaches an asymptote very quickly, even for small values of $N$:

import numpy as np

def worst_approx(N):
    """
    This function outputs the worst quality approximation out of 100 runs
    for a given N.
    """

    quality = np.zeros(100)

    for i in range(100):

        ## draw a random non-negative vector of length 10
        R = np.random.rand(10)

        ## normalise the vector so that it sums to one
        P = R/np.sum(R)

        ## calculate the maximum
        max_P = np.max(P)

        ## use an approximation to the infinity norm:
        max_approx = np.sum(P**N)**(1/N)

        vals = [max_P, max_approx]

        ## quality of approximation (ratio of the smaller to the larger value)
        quality[i] = np.min(vals)/np.max(vals)

    return np.min(quality)

Q_min = np.zeros(11)

for i in range(11):

    Q_min[i] = worst_approx(i+5)
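
As a quick sanity check on a single vector, the approximation can also be inspected directly; the vector P below is arbitrary and purely illustrative:

import numpy as np

## arbitrary probability vector, chosen only for illustration
P = np.array([0.05, 0.15, 0.5, 0.3])

for N in [2, 5, 10, 20, 50]:
    ## smooth approximation to the maximum via the N-norm
    approx = np.sum(P**N)**(1/N)
    print(N, approx, np.max(P))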

Analysis:

In order to determine whether this approximation is useful for propagating gradients, it’s sufficient to calculate the partial derivative with respect to a particular weight matrix $W$ of the function approximator:

$$\frac{\partial}{\partial W}\left(\sum_i p_i^N\right)^{\frac{1}{N}} = \left(\sum_j p_j^N\right)^{\frac{1}{N}-1} \sum_i p_i^{N-1}\,\frac{\partial p_i}{\partial W}$$

and we have equality with $\frac{\partial p_{i^*}}{\partial W}$, the gradient of the largest probability $p_{i^*} = \max_i p_i$, in the case that $p_{i^*} = 1$ and $p_j = 0$ for all $j \neq i^*$, or when $N \to \infty$. This suggests that although this approximation may allow propagation of gradients, it isn’t particularly good at rewarding efficient exploration. Moreover, this approximation is strictly valid only if, at each instant, our controller behaved in a deterministic manner by choosing the most probable action rather than sampling actions from a multinomial distribution with parameters $P$.
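
As a minimal illustration of this behaviour (using an arbitrary probability vector), the weights $p_i^{N-1}\left(\sum_j p_j^N\right)^{\frac{1}{N}-1}$ that multiply each $\frac{\partial p_i}{\partial W}$ in the expression above can be computed directly; as $N$ grows, the weight on the most probable action approaches one while the others vanish, consistent with the equality conditions just discussed:

import numpy as np

## arbitrary probability vector, chosen only for illustration
P = np.array([0.1, 0.6, 0.2, 0.1])

for N in [2, 5, 10, 50]:
    ## weights multiplying each dp_i/dW in the gradient of (sum_i p_i^N)**(1/N)
    weights = P**(N - 1) * np.sum(P**N)**(1/N - 1)
    print(N, np.round(weights, 4))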