schedule2019-03-14

P5jsでk-meansのクラスタリング過程を可視化する

P5jsを使ってk-meansのクラスタリング過程の可視化に挑戦しました。 また、とてもざっくりとした解説を載せます。

k-meansとは?

データを与えられたk個のクラスタに分類するアルゴリズムです。

クラスタリングする手順は一般的に以下の通り。

  1. データをランダムにk個のクラスターに割り振る。
  2. クラスターの中心を求める。
  3. 各クラスタの中心との距離を求め、一番近い中心のクラスタに割り振ります。
  4. 2と3を繰り返してクラスターが変動しなくなれば終了。

step2とstep3のイメージ

step2

step2からstep3へクラスタが更新している。

k-meansの長短

機械学習の中では教師なし学習にあたります。 つまり、ラベル付けなどの作業をしなくても良い感じにデータを分けてくれます。

ただし、以下の点に注意が必要です。

  • 最終的なクラスタは初期値に依存する。
  • クラスタ数がデータに合わないことがある。

ここでは簡単な説明にとどめます。以下の記事が詳しく解説しています。

シミュレーション

実際に動かして体験してみましょう。

リセットを押すと大体600個ほどの点群が出来ます。 ステップを押すと「シードの移動」と「クラスタリング」を繰り返して収束していきます。

それぞれのクラスターのサイズと位置(x, y)、クラスタ数を表示ししています。

初期配置は1. データをランダムにk個のクラスターに割り振る。ではなく、ランダムに初期の点を置いています。 データセットは、前回作成した2次元の混合分布を使ってそれっぽい分布にしています。

ソースコード全文

読みづらくて申し訳ない。

キャンバスと表示制御部分のhtml

<!-- p5js -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.7.3/p5.min.js"></script>

<script src="/images/posts/115/k-means.js"></script>

<div class="p5js">
  <div class="center">
    <button type="button" class="button" onclick="retry();">リトライ</button>&nbsp;
    <button type="button" class="button is-primary" onclick="step();">▶ ステップ</button>&nbsp;
  </div>
  <div  class="center">
    シード数:<span id="count">3</span>&nbsp;
  </div>
  <div id="canvas"  class="has-text-centered"></div>
  <div class="center" id="parameters"></div>
</div>


<style>
.p5js .center{
  text-align: center;
  margin: 10px;
}
.p5js .button{
  text-align: center;
  margin: 10px;
}
</style>

k-means.js

let data = []; // 分類したデータ
let points = []; // 全てのポイント
let seeds = [];

let frame_rate = 10;
let parameter = [];

let canvas_size = {
  width: 400, height: 400
} //px

let bg_color = '#FFFCDB';
let cluster_colors = ['#A40000', '#0075A9', '#007130', '#A4005B', '#B7AA00', '#100964'];

let cluster_size = 3;
let next_step_flg = false;

// 再描画
function retry() {
  ini();
}

// step
function step() {
  drawView();
}

// 初期化
function ini() {
  parameter = [];
  seeds = iniSeed();
  background(color(bg_color));

  data = createData();
  points = join_dataset(data);
  data = [points];
  drawCluster();
  next_step_flg = true;
}


// シードの初期値
function iniSeed() {
  array = new Array(cluster_size);
  for (let i = 0; i < cluster_size; i++) {
    let x = random(0, canvas_size.width);
    let y = random(0, canvas_size.height);
    array[i] = [x, y];
  }
  return array;
}

// ランダムな混合分布のデータセットを作成
function createData() {
  let clusters = [];
  for (let i = 0; i < 6; i++) {
    let size = int(random(100, 200));
    let x_mean = random(0, canvas_size.width);
    let y_mean = random(0, canvas_size.height);
    let x_stddiv = random(20, 50);
    let y_stddiv = random(20, 50);

    let c = createGaussianDistribution(size, x_mean, x_stddiv, y_mean, y_stddiv);
    clusters.push(c);
  }

  return clusters;
}

// ガウシアン分布から乱数で生成する
function createGaussianDistribution(size, x_mean, x_stddiv, y_mean, y_stddiv) {
  // パラメータ表示用
  parameter.push({
    size: size,
    mean: {
      x: x_mean, y: y_mean
    },
    stddiv: {
      x: x_stddiv, y: y_stddiv
    }
  });
  // 点群のリスト
  let array = [];
  for (let i = 0; i < size; i++) {
    // ガウシアン分布のランダムな値
    let x = randomGaussian(x_mean, x_stddiv);
    let y = randomGaussian(y_mean, y_stddiv);
    // 範囲外のポイントは除く
    if (isOut(x, y)) continue;
    array.push([x, y]);
  }
  return array;
}

// 描画範囲外か判定
function isOut(x, y) {
  if (x < 0 || canvas_size.width < x ||
    y < 0 || canvas_size.height < y) {
    return true;
  }
  return false;
}

function join_dataset(clusters) {
  array = [];
  clusters.forEach(function (element) {
    array = array.concat(element);
  });
  // 元の3次元配列を崩さない様にするため
  return array;
}


// p5js
function setup() {
  let canvas = createCanvas(canvas_size.width, canvas_size.height);
  canvas.parent('canvas');
  // フーレームレートを1/1secにする
  frameRate(frame_rate);

  background(color(bg_color));

  ini();
  drawView();
}


function drawView() {
  // 描画
  // 1ステップごと描画
  background(color(bg_color));

  if (next_step_flg) {
    compairePoints();
    drawCluster();
    drawSeedPoints();
  } else {
    moveSeeds();
    displayParameters();
    drawCluster();
    drawSeedPoints();
  }
  next_step_flg = !next_step_flg;
}

// データセットの描画
function drawCluster() {
  for (let i in data) {
    drawPoints(data[i], cluster_colors[i]);
  }
}

// 点の描画
function drawPoints(array, color) {
  let r = 6;
  noStroke();
  fill(color);
  for (let point of array) {
    ellipse(int(point[0]), int(point[1]), r, r);
  }
}

// シードの描画
function drawSeedPoints() {
  let r = 15;
  for (let i in seeds) {
    stroke(cluster_colors[i]);
    strokeWeight(4);
    let point = seeds[i];
    fill('#00000');
    ellipse(int(point[0]), int(point[1]), r, r);
  }
}

function compairePoints() {
  // 空にする。
  data = new Array();
  for (let i in seeds) {
    data.push([]);
  }
  for (let point of points) {
    let min = 100000;
    let j = 0;
    for (let i in seeds) {
      let _min = diffelense(point, seeds[i])
      if (min < _min) continue;
      j = i;
      min = _min;
    }
    data[j].push(point)
  }
}

// 点の距離
function diffelense(p1, p2) {
  return Math.sqrt(Math.pow(p1[0] - p2[0], 2) + Math.pow(p1[1] - p2[1], 2))
}

function moveSeeds() {

  for (let i in data) {
    let sum = [0, 0]

    if (data[i].length < 1) {
      // ポイントが無ければランダム
      seeds[i][0] = random(0, canvas_size.width);
      seeds[i][1] = random(0, canvas_size.width);
      continue;
    }

    for (let point of data[i]) {
      sum[0] += point[0];
      sum[1] += point[1];
    }
    seeds[i][0] = sum[0] / data[i].length;
    seeds[i][1] = sum[1] / data[i].length;
  }
}


// パラメータの表示
function displayParameters() {
  let html = '';
  for (let i in seeds) {
    let txt = '<p style="color:' + cluster_colors[i] + ';">';
    txt += 'サイズ:' + data[i].length + '&nbsp;';
    txt += '位置:(' + seeds[i][0].toFixed(1) + ', ' + seeds[i][1].toFixed(1) + ')';
    txt += '</p>';
    html += txt;
  }
  document.getElementById('parameters').innerHTML = html;
}