スチールウールの活動記録

初心者に向けて(私もだけど)電子工作やプログラミングや関連することをいろいろ。

caffeのサンプルを動かす

techblog.yahoo.co.jp

qiita.com

この2つのサイトを見ながらサンプルだけ動かしてみた。

caffeのトップディレクトリにいる想定で進めていく。

サンプルをコマンドで持ってくる

wget http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz
tar zxvf 101_ObjectCategories.tar.gz

モデルファイルを持ってくる

cd models/bvlc_reference_caffenet/
wget http://dl.caffe.berkeleyvision.org/bvlc_reference_caffenet.caffemodel

get_ilsvrc_aux.shを実行する

cd ../../data/ilsvrc12
./get_ilsvrc_aux/sh

classify.pyを実行

cd ../../python
python classify.py --raw_scale 255 ../101_ObjectCategories/airplanes/image_0001.jpg ../result.npy

上記でエラーが起きたら caffe/python/caffe/io.pyの253-254行目の

if ms != self.inputs[in_][1:]:
    raise ValueError('Mean shape incompatible with input shape.')

if ms != self.inputs[in_][1:]:
    print(self.inputs[in_])
    in_shape = self.inputs[in_][1:]
    m_min, m_max = mean.min(), mean.max()
    normal_mean = (mean - m_min) / (m_max - m_min)
    mean = resize_image(normal_mean.transpose((1,2,0)),in_shape[1:]).transpose((2,0,1)) * (m_max - m_min) + m_min
    #raise ValueError('Mean shape incompatible with input shape.')

と書き換える。

show_result.pyを書いて、実行する

#! /usr/bin/env python
# -*- coding: utf-8 -*-
import sys, numpy

categories = numpy.loadtxt(sys.argv[1], str, delimiter="\t")
scores = numpy.load(sys.argv[2])
top_k = 3
prediction = zip(scores[0].tolist(), categories)
prediction.sort(cmp=lambda x, y: cmp(x[0], y[0]), reverse=True)
for rank, (score, name) in enumerate(prediction[:top_k], start=1):
    print('#%d | %s | %4.1f%%' % (rank, name, score * 100))

実行すれば分類結果が表示される。

めんどくさいのでスクリプト

#!/bin/sh
cd $CAFFE_DIRECTORY/python
if [ $# != 1 ]; then 
    echo "Argument error"
else
    python classify.py --raw_scale 255 $1 ./result.npy
    python show_result.py ../data/ilsvrc12/synset_words.txt result.npy $1
fi