Python, scikit-learn

【scikit-learn】ライブラリ付属のデータセットまとめ – 概要、使い方を解説


Pythonの機械学習ライブラリであるscikit-learnには、付属の実験用データセットが用意されています。

これを利用することで手軽に機械学習手法の実験が可能です。

この記事では、scikit-learnに付属する全データセットの概要と使い方を解説します。

scikit-learn付属のデータセット概要

まずはscikit-learnに付属するデータセットの基本的な情報を説明します。

小規模、大規模のデータセットが2種類

scikit-learnで利用可能なデータセットは大きく分けて以下の2種類です。

  • Toy datasets(7個)
    件数の少ないデータセット(ダウンロードなしで利用可)
  • Real world datasets(9個)
    件数の多いデータセット(ダウンロードが必要)

これらのデータセットを用途や実験の目的に応じて使い分けることができるため、scikit-learn単体で非常にバラエティーに富んだ実験を行うことができます。

「Toy datasets」の概要、使い方

Toy datasetsはデータの件数が少なめですが、ダウンロードなしで利用することが可能です。

全部で7つあるデータセットはsklearn.datasetsクラスからインポートして利用します。使用例を以下に示します。

>>> from sklearn.datasets import load_boston
>>> loaded = load_boston()
・・・
>>> x = loaded.data    # 学習データ
>>> y = loaded.target  # 正解データ
・・・
>>> print(x.shape)
(506, 13)
>>> print(y.shape)
(506,)

データセットによってデータ件数や用途は異なるため、その点を考慮して使用して下さい。

なお、各データセットについての説明は以下の通りです。

load_boston()

ボストン市内の地域ごとの犯罪率や税率などの情報(13項目)を入力データとしてその地域の住宅価格を予測します。

load_iris()

花弁や萼片の長さなどの情報(4項目)を入力データとしてアヤメの種類を予測します。

load_diabetes()

糖尿病患者の検査結果(10項目)を入力データとして1年後の病気の進行状況を予測します。

load_digits()

手書きの数字の画像(8×8)に0~9のどの数字が描かれているかを予測します。

load_linnerud()

成人男性によるスポーツテスト(3種類)の結果を入力データとしてその人の生理学的な特徴(3種類)を予測します。

load_wine()

アルコール度数や色味などの情報(13項目)を入力データとしてワインの種類を予測します。

load_breast_cancer()

癌の診断結果(30項目)を入力データとして良性の癌か悪性の癌かを予測します。

「Real world datasets」の概要、使い方

Real world datasetsはToy datasetsよりもデータ件数の多いデータセットですが、利用する前にダウンロードを行う必要があります。

全部で9つあるデータセットはsklearn.datasetsクラスからインポートして利用します。使用例を以下に示します。

>>> from sklearn.datasets import fetch_olivetti_faces
>>> loaded = fetch_olivetti_faces()
downloading Olivetti faces from https://ndownloader.figshare.com/files/5976027 to C:\Users\user\scikit_learn_data
・・・
>>> x = loaded.data    # 学習データ
>>> y = loaded.target  # 正解データ
・・・
>>> print(x.shape)
(400, 4096)
>>> print(y.shape)
(400,)

ダウンロードは最初にデータセットをインポートしたプログラムの実行時に自動で行われ、それ以降のプログラム実行時には行われません。

なお、各データセットについての説明は以下の通りです。

fetch_olivetti_faces()

それぞれ異なる人物の顔写真(64×64)から写っている人物を予測します。

fetch_20newsgroups()

ニュース記事の本文(テキストデータ)からどのトピックに該当するかを予測します。

fetch_20newsgroups_vectorized()

ニュース記事の本文(ベクトルデータ)からどのトピックに該当するかを予測します。

fetch_lfw_people()

有名人の顔写真(62×47)から写っている人物を予測します。

fetch_lfw_pairs()

2枚組の有名人の顔写真(62×47)が同一人物のものかを予測します。

fetch_covtype()

土地の標高や傾斜、土壌の種類などの情報(54項目)を入力データとしてその土地にある森林の木の種類を予測します。

fetch_rcv1()

ニュース記事より抽出した特徴量(ベクトルデータ)を入力データとしてどのカテゴリに属するかを予測します。

  • 用途:分類(103クラス)
  • データ件数:804414
  • データ次元数:47236
  • ドキュメント:sklearn.datasets.fetch_rcv1

fetch_kddcup99()

持続時間やプロトコルタイプなどの情報(41項目)を入力データとしてネットワークへの攻撃かどうか、その攻撃の種類は何かを予測します。

fetch_california_housing()

カリフォルニア州の地域ごとの人口や住人の収入などの情報(8項目)を入力データとして住宅価格を予測します。