主要なDeep Learning ライブラリの一覧

0. Deep Learning ライブラリ 一覧の概要と記事の構成

このページでは,近年 (2022年 現在)よく使用されている代表的な Deep Learning 向けのライブラリ・フレームワークを,カテゴリごとに整理する.

1節以降,汎用的ライブラリ,ビジョン向けライブラリ,NLP向けライブラリの,3カテゴリに分けて紹介する:

  • 1節 汎用的な Deep Learning ライブラリ
    • 1.1 2大 DLライブラリ
      • PyTorch, lightning
      • Tensforflow2, Keras
    • 1.2 第3の候補:JAX系ライブラリ:Flax や Haiku など
  • 2節 コンピュータビジョン向けライブラリ
    • 2.1 大きな実装コードベース.OpenMM など
    • 2.2 ビジョンのタスク・モデル別
    • 2.3 Tranformer 系 (ViT, DETR )
  • 3節 NLP向け
    • HugginFace
  • 4節 その他

公的な研究・論文向けの,無償で使用できるものを中心に一覧化してある.商用使用を目的とするライブラリや,クラウドデプロイ目的の商用サービス(例えばAmazon SageMakerなど)は,このページではリストアップしていない.

※ この記事では,「ライブラリ」と「フレームワーク」の2つを,便宜上,同じ意味であるとグループ化し,両者を区別せずにライブラリとまとめて呼びたい.この記事で登場するものには,ライブラリなのかフレームワークなのかの境界線が怪しいものも多く,いちいち呼び分けるの面倒だからである.また「ツールキット」もその中に含める.

1. 代表的な Deep Learning ライブラリ

1.1 2大 DLライブラリ

汎用的な,特定のタスクやライブラリに特化していないディープラーニング向けのライブラリでは,PyTorchTensorFlow (+ Keras)が,特に研究開発向けでは2強であった (2021年ごろまで).

  • PyTorch 【PyTorchの紹介記事
    • Pytorch-Lightning:軽量・高速化されたPyTorchラッパー.2021以降よく使われている印象.
    • Catalyst:軽量・高速化されたPytorchラッパー.学習ループを各関数で分離分割管理でき,関数型プログラミング的に短く書ける.
    • Pytorch Geometric:GraphNNや,測地線,メッシュ幾何など幾何系処理全般むけ.
  • TensorFlow2
    • Keras:コーディングに癖があるTensorflow本体と違い,Kerasはシンプルなインターフェースであり,PyTorchのようにさくさくコードを書ける. 
    • Tensorflow Lite :エッジ計算・モバイル向け.
    • Tensorflow.js:ブラウザ向け.

これらに加えて MXNet 勢,というのが,2021年前半の状況だった.そこに,Google界隈から「JAXを用いた深層学習フレームワーク」が登場しはじめる.

※ [2022年6月末追記] PyTorch 1.11から,β版として,JAXライクに,gradやvmapなどを使えるfunctorchモジュールが追加された.JAXの影響の波及が,ついに本格的になってきたといえる.

1.2 第3の定番へ: JAXを元にした深層学習ライブラリ

2021年ごろから「Google JAX (1.2.1節)を用いた深層学習ライブラリ」が第3の定番DLライブラリの主軸候補として,研究者中心に普及しはじめている (1.2.2節).JAX をバックエンドに用いた深層学習ライブラリを用いると,JAXが提供する関数 (vmap,grad, vjpなど) による,「コンポーザブルな関数変換」の恩恵を,科学技術計算・深層学習のコーディングにおいても得られる (JAXのgithubの説明などを参照のこと).

ViT(Vision Transformer)やMLP-mixerの,Googleによる公式コードgithubレポジトリはJAX/Flaxベースであるなど,最先端研究者だと,JAXも触る必要性が少しずつ出てきている状況である.

1.2.1 JAXについて:改良版AutogradとXLA(機械学習特化のJITコンパイラー)

JAXは,「深層学習ソフトウェアのハイパフォーマンス化,およびその容易な短いコードでの実現」を特に意識した,Googleが中心で開発されている機械学習向けの数値計算ライブラリである.

JAX は「改良版Autograd」と「XLA(機械学習向けのコンパイル最適化)」の2つから構成されいる.シンプルな短いコードで済みながら,高性能な機械学習向けソフトウェアを開発できる.要するに,PythonとNumpyを,簡単な追加コードだけで,GPU・TPUむけに,数値計算や微分計算を高速化してくれるものがJAXである.

JAXでは,Google独自の改良版autogradが使用されている.よって,先行のPyTorchなどに含まれる(通常版の)autogradよりも,自動微分の性能が良くなっており,深層学習に力を発揮する.

また,XLAでは,numpyを含むコードを,Just-in-time コンパイルしてくれるので「関数単位の高性能化・省メモリ化」をXLAにお任せできる.具体的には,XLAは,各関数を,簡単な記述を追加するだけで「自動並列化・自動ベクトル化・JITコンパイル」を行うことができ,ソフトウェアのパフォマーンスを向上させてくれる(初めての人には,ドキュメントのQuick Startが参考になる).

よって,JAXにおいても,XLAを使用して,コーディングの負担が少ないまま,機械学習ソフトウェアの高性能化が達成できる(※ XLAと類似するものに,Anacondaの「numba」によるnumpy計算のJITコンパイルがある.numbaでも@jitと各関数に書くだけでJITコンパイルが実行できる).

JAXについてより詳しくは,公式githubの「What is JAX?」の節を参照のこと.

1.2.2 JAXベースのシンプルなDLライブラリ:HaikuとFlax

2021以降,研究者中心に,JAXベースのDeep Learningライブラリを使ったコードをよく見るようなり,それまでの2大定番(Pytorch Tensfor flow)と合わせて,第3の定番と化しつつある.

具体的は,以下の,HaikuFlaxTraxなどが,研究発表でも公開コードで見る機会が増えてきた(特にFlaxが普及しつつある):

  • Haiku(Google DeepMind):
    • DeepMindの Sonnet(Tensorflow2ベース)を,JAXベースのライブラリにしてパワーアップさせたものが,このHaikuである.
    • よって,DeepMindチームの先端的な基礎研究成果は,最近だとこのHaikuでコードが公開されている.
  • Flax (Google):
    • Haikuとは別に,マウンテンビューのGoogle本体やResearchチームによって作られているのがFlaxである.
    • 例えばViT(Vision Transformer)の公式コードは,JAX/Flaxである.つまり,ViTやMLP-Mixerなどの公式コードを研究で扱うには,まずFlaxなのである(とはいえ,その後,Pytorch/visionやTensforflow, huggingface版も用意されている)
    • 一方で,DETRは,facebookから提案されたのでpytorchであるなど,google以外のTransformer系の先端研究は,まだJAXベースではない.あくまで,Google界隈でのJAX(+Flax)の使用が増えている段階である(※ 2022年3月執筆).

JAXベースのこれらのライブラリの方が,シンプルで部品的(modular)なコードで,高速かつスケールする深層学習を行える.Haikuは特にそうだが,Tensorflowベースなので,元々Tensorflowだったものを簡単にHaikuのコードへ移行できる(= Google同士なので,そうなるように設計されている).

HaikuもFlaxも,両者は「コード行数を増やさず,短いコードで機械学習の高速化・スケール化ができる」という目的のJAXが元になっている上に,以前から存在するPytorch lightningなどの「軽量でシンプル化して,なおかつ部品的に書ける」ライブラリの理念もアップデートしている.

HaikuとFlaxは,オブジェクト名こそ結構違うが(※ 全結合層の名前は,HaikuはDense()で,FlaxはLinear()),共にJAX基盤で,なおかつ達成したことも同じなので,コードの書きかたも非常に似ているし,短いコードで済む具合も似ている.FlaxではテンソルもNHWCの順序でデータを格納するようになっており,数式的な配置でない順序が多かったライブラリよりも直感的である(PyTorchはNCHWの順序).numpyよりjuliaの方が数式どおりに,線形代数をプログラミングできて直感的であるのと同じ話である.現代的な数値計算ライブラリ・フレームワークほど,数学の式に寄せている.

Flaxは関数型言語の設計:おかげで柔軟性が高いが,ハードルもやや高い

Flaxは(純粋)関数型言語の方針で設計されており,おかげでかなり柔軟性の高いコーディングが可能である.一方で,従来から人気のPyTorchやKerasなどは「オブジェクト指向型」のフレームワークであった.よって,関数型言語に慣れていないまま,いきなりFlax・JAXへ移行するには,関数型プログラミングが苦手な人だと,慣れや勉強が必要な側面もある (Flaxのドキュメントのhow do I …?のページに「flaxで…を実装したいときにはどうすれば良いのか」の解説ページリンクが,一覧化されてあり参考になる).

もちろん,部品(既存の層)を組み合わせて済む程度のネットワークをつくるだけなら,書きかたの流儀が違ってもJaXはたいして難しくない.しかし,ローレベルの処理から,JAX系フレームワーク(FlaxやTraxなど)で書こうとすると,関数型コーディングに慣れていないと難しい場面も出てくる.

また,従来「PyTorch → TensorFlow」間でのコードだと,ポーティングは比較的楽にできるのに対して,「PyTorchやHuggingFaceのコード → Flax/JAXのコード」のポーティングは,指向性が異なるので簡単ではなく,当然Flaxで用意されている各部品 (層やOptimizer)への固有の知識も必要となる.(FlaxドキュメントのConvert PyTorch Models to Flaxが参考になる).

このように,関数型設計でハードルが高い面もあるので,管理人としても,FlaxへのJAX系フレームワークへの移行を全員にお薦めするわけではない.「色々な研究で見かけるようになってきたので,存在は知っておくとよい段階になった」ので,とりあえず概要だけ書いた,としておきたい.

なぜGoogle内で2種類つくられている?

「コードが似ているなら,なぜ同じGoogle社内で,HaikuとFlaxを別々に作る意味があるの? 」という疑問もでるのが当然である.DeepMind は少数精鋭部隊であり,本社と同じマウンテンビュー所属の会社であるものの,Andrew Zisserman のVGGラボ(英 Oxford大)も提携相手なのでしてやりとりしている.よって,DeepMindは,外から見るとGoogle内でも独立部隊感が強い.そういう意味で,DeepMindのHaikuと,Google 全体のFlaskが2個あるは仕方ない.ただし「Brainグループも,Flaxの開発には関わってる」とgithubレポジトリには書いているので,結局実際の社内の関係は我々にはよくわかならい.

以上,2者の社内位置づけを考えると,Google本体が作る Flaxは,エンジニア・デプロイ志向がありそうで,Haikuは「(DeepMind的な)基礎研究のモデルや問題を提案するリサーチャー」やそうでなくても応用的なリサーチャー全般が使いやすいと,とりあえず考えておいたら良いとは思う.とはいいながら,両者のドキュメントを眺めていただくとわかるが,2者のそうした違いは,あまりライブラリ設計からは読み取れない.

いずれにせよ,自社のTPUにも特化したJAX系のライブラリを,深層学習向けにも構築・提供することで,Googleの自社クラウドサービスや,Colaboratoryなどの利用を促進させることが,彼らの共通のミッションであると思う.いくつかあるJAXベースの深層学習ライブラリのうち,そこそこコミュニティが大きく育っていったものから将来的に集中的に投資される,という感じのスタンスであろう.

1.3 その他の汎用DLライブラリ

  • Apache MXNet:オープンソースソフトウェアの団体Apacheから提供されている,汎用深層学習ライブラリ
  • CNTK ( from microsoft)

2 コンピュータビジョン向けライブラリ

2.1 多種タスクのコードベース/Example集

  • コンピュータビジョン・ディープラーニング全般のライブラリ
    • OpenMMLab:SenseTimeやその創立元である香港中文大学(CUHK)の研究者・学生が中心として開発されている,PyTorchベースの深層学習ライブラリ.
      • MM=Multimediaであるのは,SenseTimeの創業を行ったXiaoou Tangの研究室MMLabが,コンピュータビジョンだけでなく,画像映像に対するマルチメディア系の研究(画像検索・映像検索・映像要約など)も含めて研究していて広義で「マルチメディアラボ」という名前の研究室であるため.
      • OpenMMLabは,(主にコンピュータビジョンの)タスクごとに,細かくサブフレームワークに分離されている(代表的ものは3.2節で紹介).特に物体検出向けの MMDetection が有名だったが,2020年以降は,他のタスクの MMSegmentation や MMAction2 なども充実してきている印象.
      • 全体像や使用法は CVPR2021のチュートリアルが参考になる.
      • 一方で,Mask R-CNN の提案以降,Facebookラボが得意とするインスタンスセグメンテーションに関しては,OpenMMLab は成長できていない印象.
    • GluonCV:MXNetを使用したコンピュータビジョン用のフレームワーク.
      • 画像認識/物体検出/セグメンテーション/人物姿勢推定/アクション認識など,代表的な問題設定むけのDeepネットワークを,少ない行数で構築することができる.
      • Model Zoo を見ると,用意されている学習済みネットワーク一覧を確認できる
    • NVIDIAのDeep Learning Examples

2.2 タスク別のフレームワーク(ツールキット)

2.3 Transformer系

ライブラリではなく個別のソフトウェアが多いが,このページに一緒に整理しておくと皆様は便利だと思うのでまとめる:

3. NLP 向けライブラリ

  • AllenNLP:Allen Instituteが提供するDeep Learning手法ベースの自然言語処理ライブラリ.
  • Huggingface🤗のTranformers:Transformersを仕様した各種の有名モデルを提供する,自然言語処理を中心としたタスク向けのフレームワーク.PyTorch(主), Tensforflow, JAXに対応.
    • Transformerを用いた最新のNLPモデルを,誰でも簡単に使うことができる.BERTやGPT3などの「事前学習モデル」がすぐ使えるのが,他のフレームワークと比べると便利.
    • supported framework に,モデルごとにどのライブラリで対応されている.主対応先のPyTorchでは,基本的に全モデルがサポートされている.また,Tensorflow, JAX 向けにも,サポートされているモデルが増えてきている.
    • Huggingface Transformerの初心者は,公式サイトが提供するコース (Pytorch/Tensorflow)をまず参照のこと.登場の経緯については,TechBlitzでの創業者インタビュー が参考になる
    • (2022年3月追記) NLP以外でもTransformerモデルが多く提供されている.
      • 初期はNLPモデル中心であったが,最近はViTやDeiTDETRSegFormerなどのコンピュータビジョン向けTransformerモデルも提供されるように.
      • Wav2Vec2,Speech2TextなどのTTSやSSTなど音声-言語間変換のTransformer系モデルも提供.
    • (2022年5月追記) 日本語BERT の例:
      • HuggingFace Transformersに,「日本語で事前学習したBERT」が集結しつつある.
      • 例1 ) 乾研の BERT
      • 例2 ) 黒橋研 BERT-japanese(ただし,古くなっているので以下のROBERTa-japanese使用を推奨している)
      • 例3 )河原研 ROBERTa-baseROBERTa-LARGE
  • spaCy: 実応用向けの,Pythonによる総合NLPツールキット.
    • 最先端の研究者やその成果を使いたい人はHuggingfaceを使うようになってきているが,spaCyは様々なタスクの学習済みモデルが収録されており使い勝手が良い
    • v2.3 から日本語モデルも正式に追加され,日本語でのNLPにも使える.更に, v3.0からTransformerモデルも導入された.
    • また,64以上の言語向けで各種NLPモデルが収録されており,Pytorch/Tensorflowで独自モデルも組める.
  • Stanza (github) : Stanford NLP Groupによる,PyTorchベースの自然言語処理ライブラリ.
    • 66種類の言語向けに「多言語対応」している
    • Stanford としては,NLTKの後継「多数言語向け」ライブラリである.
    • spacyでのラッパーを使うと,spacyからもStanzaの各機能を使える.

4. その他 ライブラリ(用途別)