KLダイバージェンスの非対称性

KLダイバージェンスは、2つの確率分布間の「距離」を測る手法ですが、その距離は対称的ではありません。つまり、分布 P から分布 Q へのKLダイバージェンス DKL(PQ) と、分布 Q から分布 P へのKLダイバージェンス DKL(QP) は通常異なる値になります。

数式で示すと以下のようになります:

$$ D_{\text{KL}}(P \| Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)} \\ D_{\text{KL}}(Q \| P) = \sum_{x} Q(x) \log \frac{Q(x)}{P(x)}

$$

この非対称性は、どちらの分布を基準にするかで異なる情報を提供するため、特定の問題において片方のダイバージェンスを選ぶ理由になります。たとえば、モード探索や最適化の際に、どちらのKLダイバージェンスを最小化するかによって得られる解が異なることがあります。

実装

Google Colab

  def update(self,x0: np.ndarray) -> Tuple[float, np.ndarray]:
    u = self.u_prev
    # 現在の車両位置に最も近いウェイポイントを取得
    a = self.get_nearest_waypoint(x0[0], x0[1], update_prev_idx=True)
    # 状態コストリスト
    S = np.zeros((self.K))
    # ノイズをサンプリング
    epsilon = self.calc_epsilon(self.Sigma, self.K, self.T, self.dim_u) # size is self.K x self.T
    # サンプリングされた制御入力シーケンスのバッファを準備
    v = np.zeros((self.K, self.T, self.dim_u))  # ノイズ付きの制御入力系列
    # 0 ~ K-1 のサンプルのループ
    for k in range(self.K):
      # 初期(t=0)の状態 x、つまり車両の観測状態を設定
      x = x0
      # 時間ステップ t = 1 ~ T のループ
      for t in range(1, self.T + 1):
        # ノイズ付きの制御入力を取得
        if k < (1.0 - self.param_exploration) * self.K:
          v[k, t - 1] = u[t - 1] + epsilon[k, t - 1]  # 探索のためのサンプリング
        else:
          v[k, t - 1] = epsilon[k, t - 1]  # 探索のためのサンプリング
        # x を更新
        x = self.F(x, self.clamp_vel(v[k, t - 1]),self.delta_T)
        # ステージコストを追加
        S[k] += self.c(x) + self.param_gamma * u[t - 1].T @ np.linalg.inv(self.Sigma) @ v[k, t - 1]
      # 終端コストを追加
      S[k] += self.phi(x)
    # 各サンプルの情報理論的重みを計算
    w = self.compute_weights(S)
    # w_k * epsilon_k を計算
    w_epsilon = np.zeros((self.T, self.dim_u))
    for t in range(self.T):  # 時間ステップ t = 0 ~ T-1 のループ
      for k in range(self.K):
        w_epsilon[t] += w[k] * epsilon[k, t]
    # 入力シーケンスを平滑化するために移動平均フィルターを適用
    w_epsilon = self.moving_average_filter(xx=w_epsilon, window_size=10)
    # 制御入力シーケンスの更新
    u += w_epsilon
    # 最適軌道を計算
    optimal_traj = np.zeros((self.T, self.dim_x))
    if self.visualize_optimal_traj:
      x = x0
      for t in range(self.T):
        x = self.F(x, self.clamp_vel(u[t - 1]),self.delta_T)
        optimal_traj[t] = x
    # サンプリングされた軌道を計算
    sampled_traj_list = np.zeros((self.K, self.T, self.dim_x))
    sorted_idx = np.argsort(S)  # 状態コストでサンプルをソート、0番目が最良のサンプル
    if self.visualze_sampled_trajs:
      for k in sorted_idx:
        x = x0
        for t in range(self.T):
          x = self.F(x, self.clamp_vel(v[k, t - 1]),self.delta_T)
          sampled_traj_list[k, t] = x
    # 前回の制御入力シーケンスを更新(左に1ステップシフト)
    self.u_prev[:-1] = u[1:]
    self.u_prev[-1] = u[-1]
    # 最適制御入力と入力シーケンスを返す
    return u[0], u, optimal_traj, sampled_traj_list

モデル

def F(x_t: np.ndarray, v_t: np.ndarray, dt: float) -> np.ndarray:
  """calculate next state of the vehicle"""
  # get previous state variables
  x, y, yaw, vx_, vy_, wz_ = x_t
  vx, vy, wz = v_t
  #_ update state variable
  x_next = x + vx * dt
  y_next = y + vy * dt
  yaw_next = yaw + wz * dt
  vx_next = vx
  vy_next = vy
  wz_next = wz
  return np.array([x_next,y_next,yaw_next,vx_next,vy_next,wz_next])

コスト関数

  def c(self, x_t: np.ndarray) -> float:
    """calculate stage cost"""
    # parse x_t
    x, y, yaw, vx, vy, wz =x_t
    yaw = ((yaw + 2.0*np.pi) % (2.0*np.pi)) # normalize theta to [0, 2*pi]

    # calculate stage cost
    _, ref_x, ref_y, ref_yaw, ref_vx, ref_vy, ref_wz = self.get_nearest_waypoint(x, y)
    stage_cost = self.stage_cost_weight[0]*(x-ref_x)**2 + self.stage_cost_weight[1]*(y-ref_y)**2 + self.stage_cost_weight[2]*(yaw-ref_yaw)**2 + self.stage_cost_weight[3]*(vx-ref_vx)**2 + self.stage_cost_weight[4]*(vy-ref_vy)**2 + self.stage_cost_weight[5]*(wz-ref_wz)**2

    # # add penalty for collision with obstacles
    # stage_cost += self._is_collided(x_t) * 1.0e10

    return stage_cost

  def phi(self, x_T: np.ndarray) -> float:
    """calculate terminal cost"""
    # parse x_T
    x, y, yaw, vx, vy, wz =x_T
    yaw = ((yaw + 2.0*np.pi) % (2.0*np.pi)) # normalize theta to [0, 2*pi]

    # calculate terminal cost
    _, ref_x, ref_y, ref_yaw, ref_vx, ref_vy, ref_wz = self.get_nearest_waypoint (x, y)
    terminal_cost = self.terminal_cost_weight[0]*(x-ref_x)**2 + self.terminal_cost_weight[1]*(y-ref_y)**2 + self.terminal_cost_weight[2]*(yaw-ref_yaw)**2 + self.terminal_cost_weight[3]*(vx-ref_vx)**2 + self.terminal_cost_weight[4]*(vy-ref_vy)**2 + self.terminal_cost_weight[5]*(wz-ref_wz)**2
    # # add penalty for collision with obstacles
    # terminal_cost += self._is_collided(x_T) * 1.0e10
    return terminal_cost

ノイズのサンプリング

  def calc_epsilon(self, sigma: np.ndarray, size_sample: int, size_time_step: int, size_dim_u: int) -> np.ndarray:
    """sample epsilon"""
    # check if sigma row size == sigma col size == size_dim_u and size_dim_u > 0
    if sigma.shape[0] != sigma.shape[1] or sigma.shape[0] != size_dim_u or size_dim_u < 1:
        print("[ERROR] sigma must be a square matrix with the size of size_dim_u.")
        raise ValueError

    # sample epsilon
    mu = np.zeros((size_dim_u)) # set average as a zero vector
    epsilon = np.random.multivariate_normal(mu, sigma, (size_sample, size_time_step))
    return epsilon

重みの計算

  def compute_weights(self, S: np.ndarray) -> np.ndarray:
    """compute weights for each sample"""
    # prepare buffer
    w = np.zeros((self.K))
    # calculate rho
    rho = S.min()
    # calculate eta
    eta = 0.0
    for k in range(self.K):
      eta += np.exp( (-1.0/self.param_lambda) * (S[k]-rho) )

    # calculate weight
    for k in range(self.K):
      w[k] = (1.0 / eta) * np.exp( (-1.0/self.param_lambda) * (S[k]-rho) )
    return w

moving average filter

  def moving_average_filter(self, xx: np.ndarray, window_size: int) -> np.ndarray:
    """apply moving average filter for smoothing input sequence
    Ref. <https://zenn.dev/bluepost/articles/1b7b580ab54e95>
    """
    b = np.ones(window_size)/window_size
    dim = xx.shape[1]
    xx_mean = np.zeros(xx.shape)
    for d in range(dim):
      xx_mean[:,d] = np.convolve(xx[:,d], b, mode="same")
      n_conv = math.ceil(window_size/2)
      xx_mean[0,d] *= window_size/n_conv
      for i in range(1, n_conv):
        xx_mean[i,d] *= window_size/(i+n_conv)
        xx_mean[-i,d] *= window_size/(i + n_conv - (window_size % 2))
    return xx_mean

oter