Earth Mover's Distance(2)lpSolveを使ったC実装

Earth Mover's Distance (2)
EMD with lpSolve


2021/03/11
藤田昭人


だいぶん間が空いてしまいましたが…

実は Earth Mover's Distance の実装を巡って悪戦苦闘してました。 その顛末をダラダラを書き連ねた記事が予想外に長くなってしまったので、 要点のみをピックアップして再構成したものが本稿です。

今日 Earth Mover's Distance ついてはさまざまな実装が出回っていますが、 調べてみたところ Earth Mover's Distance を最初に提案した論文 "A Metric for Distributions with Applications to Image Databases" の著者 Yossi Rubner が作成した(と思われる) Code for the Earth Movers Distance (EMD) に掲載されているC実装がオリジナルのようです。 ですが、何分にも20年以上前の1998年に作成されたコードなので、 現在のCコンパイル環境では正しく動いてくれません。

やむなく、代わりとして使える リファレンス実装のコードを探すところから 本稿の作業を始めました。


Rによる EMD の実装

さまざまな条件でググってみたところ…

前述のオリジナルのC実装にも言及している 次のブログ記事を見つけました。

aidiary.hatenablog.com

この記事では EMD が (前回紹介した) 線形計画法の輸送問題を応用した指標であることを説明したのちに、 前述のオリジナル実装に付属するデモコード example1.c を例題として Rと Python の実装例を紹介しています。

記事によれば EMD は次の公式で表されますが…


{\rm EMD} (P, Q) = \frac{\displaystyle \sum_{i=1}^m \displaystyle \sum_{j=1}^n d_{ij} f_{ij}^*}{\displaystyle \sum_{i=1}^m \displaystyle \sum_{j=1}^n f_{ij}^*}

この式の分母部分の  \displaystyle \sum_{i=1}^m \displaystyle \sum_{j=1}^n f_{ij} が輸送される荷物の重さの総和、 分子部分の  \displaystyle \sum_{i=1}^m \displaystyle \sum_{j=1}^n d_{ij} f_{ij} が輸送問題の解なんだそうです。 つまり、この式は輸送問題を解いた後、 移動させる荷物の総重量で割って 単位重量あたりの作業量を計算していることになります。

f:id:Akito_Fujita:20210309195233p:plain

上記は記事から拝借してきた オリジナル実装の example1 の条件を示した図です。 この条件で輸送問題を解くと、 荷物は次のように移動するとコストは最小となります。

移動元 移動先 荷物量
 P_{1}  Q_{1} 0.4
 P_{2}  Q_{1} 0.1
 P_{2}  Q_{2} 0.2
 P_{3}  Q_{3} 0.2
 P_{4}  Q_{2} 0.1

この場合、移動する荷物量の総和は 1.0 となるので、 輸送問題の解である 160.542763 が EMD の解になります。


Rパッケージの lpSolve

Rでの EMD 実装に使われている輸送問題のソルバー lp.transport は lpSolve パッケージに収録されています。詳細は lpSolve.pdf を見ていただくとして、インターフェースのみを次に抜粋します。

名前 説明
lp lp_solve線形・整数計画法システムへのインターフェース
lp.assign 割当問題を解くために特化した lp_solve へのインタフェース
lp.transport 輸送問題を解くために特化した lp_solve へのインターフェース
make.q8 8クイーン問題のための疎な制約行列の生成
lp.object lpオブジェクトの構造
print.lp lpオブジェクトをプリントするメソッド

線形計画法のソルバーとしては 汎用の lp割当問題 に特化した lp.assign輸送問題 に特化した lp.transport の 3つのインターフェースが定義されています。 いずれのインターフェースも lp.object で パラメータの受け渡しをしていますが、 その内訳を次に示します。

メンバー 説明
direction 入力された最適化の方向
x.count 目的関数の変数数
objective 目的関数の係数のベクトル
const.count 入力された制約の数
constraints 制約|入力されたとおりの制約行列(lp.assignやlp.transportでは返されません)
int.count 整数変数の数
int.vec 整数変数のインデックスのベクトル
objval 最適時の目的関数の値
solution 最適な係数のベクトル
num.bin.solns 返された解の数を数値で表示します
status 数字のインジケータです。0 = 成功, 2 = 実現可能な解決策がない

このパッケージ、 実はオープンソース線形計画法ソルバーCライブラリである lpSolve そのものです。 ですが、このライブラリのAPIはエントリーが多くて少々煩雑なので、 複数のエントリーを集約した3つのRコードを パッケージのインターフェースとして定義されているようです。

ちなみにライブラリAPIをそのままインタフェース化した lpSolveAPI も存在します。詳細は lpSolveAPI.pdf を確認してください。両方のパッケージともCコードは同じです。


線形計画法ソルバーライブラリ lpSolve

Rのパッケージが元はCライブラリだとわかって 「それじゃ直接コールすればいいじゃん」 と考えるのは今どきのプログラマの流儀ではないのかもしれませんが(笑)

ソルバーライブラリを探してみたところ、 なんと sourceforge(懐かしい)で、 今もメンテナンスが続いているようです。 ドキュメントを次に示します。

lpsolve.sourceforge.net

で、ライブリラリのビルドの手順を 書こうかと考えたのですが、 素敵なことに下記の Qiita の記事に 手順が書いてあったので、 こちらを参考にしてください。

qiita.com

この記事は lpSolve を Python に 取り込む方法を紹介していますが、 本家の lpSolve API バインディングを 使用しているようです。 元のブログ記事では Rパッケージを Python から使う事例が 紹介されてたのは 輸送問題のAPIである lp.transport が 使いたかったからですかねぇ。 「それ、むっちゃ効率が悪くないか?」 と思うのは僕が老人だからでしょうかねぇ?


Cライブラリの lpSolve を使って輸送問題を解く

ライブラリもビルドできたところで、 輸送問題を解くCプログラムを物色し始めました。 「わりと一般的だろうなぁ…」と想像して lpSolve を呼び出すCコードのサンプルを ググりまくったのですが、 日本語はもちろんのこと、英語でも見つからない。 「今どきのプログラマはもうCは書かないのかな?」 などと思いつつ、 輸送問題の解説 と先程の図をつきわせて 「lpSolve を使って輸送問題を解く」 勉強をしました。 (大学の情報演習みたいだ)

そもそも lpSolve で輸送問題 (というか最適化問題)を解くためには 目的関数と制約条件を設定する必要があります。 輸送問題の目的関数と制約条件を簡単に説明します。

f:id:Akito_Fujita:20210309195233p:plain

前述の図を再掲します。この図を見ると どうしても分布Pと分布Qの 都合7つのマルに目がいってしまうのですが、 輸送問題で注目するのはマルを繋ぐ都合12本の矢印だったりします。

つまり、輸送前は分布Pには図に示すとおり荷物が積まれており、分布Qには荷物は全くない状態だと考えます。 そして、輸送後は分布Pには荷物は全くなく、分布Qには図に示すとおりに荷物が積まれている状態になります。 では、輸送中は?というと都合12本の矢印に荷物が乗っている、つまり荷物が輸送されている状態です。 この時にコストを最小にするためにはどの矢印に荷物をどれくらい載せるべきか?を考えるのが 前回 紹介した「ヒッチコックの輸送問題」なんだそうです。

ここでコストと呼んでるのは分布内の各所の間の距離が違うからです。 例えばP1ーQ1間の距離とP1ーQ2間の距離は異なりますが、 この距離というのはどんな場合にも一定なので定数として扱えます。 各矢印毎のコストは距離と運ぶ荷物量を掛けた値となり、 輸送全体のコストは全ての矢印毎のコストを足しあわせたものと考えます。 前述の EMD の公式の分子部分はこの計算していますが、 これが輸送問題の「目的関数」でもあり、 輸送全体のコストが最小になるように 各矢印の間での運ぶ荷物量の調整を行います。

次に輸送問題の「制約条件」をですが、 これは分布Pと分布Qの各所が扱える荷物量に着目します。 例えば、P1はQ1、Q2、Q3に荷物を発送できますが、 その3箇所に発送する荷物量の合計は 0.4 以下に制約されます。 同様にP2、P3、P4も保有する荷物量で制約されます。 一方、Q1に着目するとP1、P2、P3、P4から荷物を受け取れますが、 その合計が 0.5 以上でなければなりません。 同様にQ2、Q3も受け取る荷物量の下限が制約されます。 このように輸送問題では分布の構成要素の総数分、 この例題では7つの「制約条件」が設定されます。

以上のように「目的関数」と「制約条件」が決められれば、 あとは lpSolve がこの問題を解いてくれます。 例題にそって条件を設定して lpSolve を呼び出す サンプルプログラム を末尾に掲載します*1。 先程ビルドした lpSolve からライブラリとヘッダーをコピーすれば、 次の手順をコンパイルできます。

$ ls inc
lp_Hash.h   lp_lib.h    lp_mipbb.h  lp_utils.h
lp_SOS.h    lp_matrix.h lp_types.h
$ ls lib
liblpsolve55.a
$ c++ -I./inc -g -Wno-macro-redefined -Wno-format -Wno-c++11-compat-deprecated-writable-strings -c example.cpp
$ c++ -L./lib -llpsolve55 -o example example.o
$ ./example
obj_val: 160.542763
variables:
0.400000
0.000000
0.000000
0.100000
0.200000
0.000000
0.000000
0.000000
0.200000
0.000000
0.100000
0.000000
$ 

ようやく冒頭のブログ記事と同じ EMD の計算値 160.542763 が確認できました。 続く variables として表示されてる12の数値は各矢印毎の輸送する荷物量、 つまり分布Pから分布Qへ荷物を輸送する差配を示しています。

しかし、この解を見てると、なんだか人間の代わりに コストが最小になるように lpSolve が考えてくれたように僕には見えてしまいます。 ある意味では人工知能の正体を覗いている感覚になります。 やはり「人工知能」という言葉やそのイメージは、 こういった数学的な解法を巧みにカモフラージュしてしまう 側面が否めないなぁと感じてしまいます。


ここまでのまとめ

lpSolve は1990年代から存在する歴史あるCライブラリで、 さまざまなプログラミング言語へのバインディングを始め、 機能が充実しています。例えば example.cpp の85行目をコメントアウトしてもらうと、 次のように目的関数と制約条件を書いた lp フォーマットのファイルが生成されます。

/* Objective function */
min: +109.927248669 C1 +97.2830920561 C2 +352.90083593 C3 +211.955183942 C4 +195.971936766 C5 +348.09481467 C6
 +244.180261283 C7 +115.429632244 C8 +254.909787964 C9 +141.435497666 C10 +52 C11 +334.752147118 C12;

/* Constraints */
+C1 +C2 +C3 <= 0.4;
+C4 +C5 +C6 <= 0.3;
+C7 +C8 +C9 <= 0.2;
+C10 +C11 +C12 <= 0.1;
+C1 +C4 +C7 +C10 >= 0.5;
+C2 +C5 +C8 +C11 >= 0.3;
+C3 +C6 +C9 +C12 >= 0.2;

このファイルはライブラリに付属する lp_solve コマンドの入力として使えます。 なので JavaScript と接続するには、 この lp フォーマットのファイルを生成して、 子プロセスで lp_solve コマンドを呼び出す方法が 最もお手軽かな?などと考えています。

ともあれ…

Rによる EMD 実装をなぞることにより Earth Mover's Distance は「輸送問題の解を正規化した指標」である ことがわかったことは収穫でした。 もっとも、今どき「輸送問題」をCで解こうとするとある穴全部にハマる… まぁ僕個人の良い勉強になったということにしておきましょう。

しかし、今どきのプログラミング言語のトレンドである 過去に実装されたコードをブラックボックスのまま取り込む風潮には、 正直「これで良いのか?」などと考えていたら、 次の書籍があと数日で出版されることを知りました。

www.hanmoto.com

僕の心配は杞憂で、この3週間は「くたびれ儲け」だった というオチが付いたところで本稿を一旦締めます。

以上

付録: example.cpp

#include <stdio.h>
#include <stdlib.h>
#include "lp_lib.h"

double f1[4][3] = { { 100,  40,  22},
            { 211,  20,   2},
            {  32, 190, 150},
            {   2, 100, 100} };
double f2[3][3] = { {   0,   0,   0},
            {  50, 100,  80},
            { 255, 255, 255} };
double w1[4] = { 0.4, 0.3, 0.2, 0.1 };
double w2[3] = { 0.5, 0.3, 0.2 };

#include <math.h>

double dist(double *F1, double *F2)
{
  double dx = F1[0] - F2[0];
  double dy = F1[1] - F2[1];
  double dz = F1[2] - F2[2];
  return(sqrt(dx*dx + dy*dy + dz*dz));
}

double a[13];

double b[13] = { 0,  //  <= 0.4
         1,  1,  1,
         0,  0,  0,
         0,  0,  0,
         0,  0,  0 };
double c[13] = { 0,  //  <= 0.3
         0,  0,  0,
         1,  1,  1,
         0,  0,  0,
         0,  0,  0 };
double d[13] = { 0,  //  <= 0.2
         0,  0,  0,
         0,  0,  0,
         1,  1,  1,
         0,  0,  0 };
double e[13] = { 0,  //  <= 0.1
         0,  0,  0,
         0,  0,  0,
         0,  0,  0,
         1,  1,  1 };


double f[13] = { 0,  //  <= 0.5
         1,  0,  0,
         1,  0,  0,
         1,  0,  0,
         1,  0,  0 };
double g[13] = { 0,  //  <= 0.3
         0,  1,  0,
         0,  1,  0,
         0,  1,  0,
         0,  1,  0 };
double h[13] = { 0,  //  <= 0.2
         0,  0,  1,
         0,  0,  1,
         0,  0,  1,
         0,  0,  1 };

void
init()
{
  a[0] = 0;
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 3; j++) {
      int n = i*3 + j + 1;
      a[n] = dist(f1[i], f2[j]);
    }
  }
}

int
main()
{
  init();

  lprec *lp;
  int ret;

  lp = make_lp(0, 12);
  if (lp == NULL) exit(1);

  ret = set_obj_fn(lp, a);
  if (ret == 0) exit(1);

  add_constraint(lp, b, LE, 0.4);
  add_constraint(lp, c, LE, 0.3);
  add_constraint(lp, d, LE, 0.2);
  add_constraint(lp, e, LE, 0.1);
  add_constraint(lp, f, GE, 0.5);
  add_constraint(lp, g, GE, 0.3);
  add_constraint(lp, h, GE, 0.2);

  //write_lp(lp, "example.lp");

  set_verbose(lp, 1); // CRITICAL; NORMAL = 4; FULL = 6;
  ret = solve(lp);
  if (ret != 0) {
    printf("status: %d\n", ret);
    print_lp(lp);
    exit(1);
  }

  double obj_val = get_objective(lp);
  printf("obj_val: %f\n", obj_val);

  double var[100];
  get_variables(lp, var);
  printf("variables:\n");
  for (int i = 0; i < 12; i++) {
    printf("%f\n", var[i]);
  }

  delete_lp(lp);
  exit(0);
}

*1:.cpp となってますが、 これはコードの途中で変数宣言をしたかったからだけで、 頭から最後までCで記述しています。