数を足し合わせる際の情報落ちへの対処法

大量の数を足し合わせる際に情報落ちをしてしまう問題への対処法メモ。
少数の数の足し合わせであれば情報落ちも無視していいレベルの場合もあるが、大量の数を足し合わせる際には無視できない数となる。また、expスケールなどの大きな数を取り扱うに当たっても注意が必要である。

logスケール変換

大量の数の下記のような和を考える。

{ \displaystyle
L = \sum^{n}_{k=1} l(k)
}

このまま計算してしまうと情報落ちが発生してしまうため、各{l(k)}{\log}スケールに変換して和の計算を行う。

{ \displaystyle
\begin{align}
\log L &= \log \left( \sum^{n}_{k=1} l(k) \right)  \\\\
&= \log l(k^{*}) + \log \left\{ \sum^{n}_{k=1} \exp \left( \log l(k) - \log l(k^{*}) \right) \right\}
\end{align}
}

ただし、 k^* = {\rm argmax}_k \ l(k)としている。
ここで、第一項は最大値をそのまま計算した値となり、第二項における \exp の計算は指数部分が全て負の値となり、情報落ちを抑えて計算することができる。

Pythonコード

この計算を計算するPythonのfunctionを下記のように作成した。

def log_sum(l_log):
    '''calculate sum of log'''
    max_log = max(l_log)
    l_log_minus = l_log - max_log
    out_value = max_log + np.log(np.sum(np.exp(l_log2_minus)))
    return out_value