このページでは,近年 (2021年 現在)よく使用されている代表的な Deep Learning 向けのライブラリ・フレームワークをカテゴリごとに整理する.
公的な研究・論文向けの,無償で使用できるものを中心に一覧化してある.よって,商用使用を目的とするライブラリや,クラウドデプロイ目的の商用サービス(例えばAmazon SageMakerなど)は,このページでは基本リストアップしていない.
※ 便宜上この記事では,「ライブラリ」と「フレームワーク」の2つを同じ意味であるとみなし,両者を区別せずに「ライブラリ」とまとめて呼びたい.この記事で登場するものは,ライブラリなのかフレームワークなのかの境界線が怪しいものが多く,いちいち「フレームワーク」と呼び分けるのも面倒なのが理由である.また,「ツールキット」もその中に含める.
目次
1. 代表的な 汎用Deep Learning Library
1.1 2大 DLライブラリ
汎用的な,特定のタスクやライブラリに特化していないDeep Learning Libraryでは,PyTorch と TensorFlow(+ Keras)が,特に研究開発向けでは2強であった(2021年ごろまで).
- PyTorch 【PyTorchの紹介記事】
- Pytorch-Lightning:軽量・高速化されたPyTorchラッパー:
- Catalyst:軽量・高速化されたPytorchラッパー.学習ループを各関数で分離分割管理でき,関数型プログラミング的に短く書ける.
- Pytorch Geometric:GraphNNや,測地線,メッシュ幾何など幾何系処理全般むけ.
- TensorFlow
- Keras:コーディングに癖があるTensorflow本体と違い,Kerasはシンプルなインターフェースであるので, PyTorchのようにさくさくコードを書ける.
- Tensorflow Lite :エッジ計算・モバイル向け.
- Tensorflow.js:ブラウザ向け.
それに加えて MXNet 勢,というのが,2021年前半の状況だった
1.2 第3の定番DLライブラリの登場: Jaxを元にしたDLライブラリ
2021年ごろから「JAX (1.2.1節)を用いた深層学習ライブラリ」が,PyTorch Tensorflow(+Keras)につづく,第3の定番DLライブラリの主軸候補として,普及しはじめている (1.2.2節)
ViTやMLP-mixerの,Googleによる公式コードgithubレポジトリはJAX/Flaxベースであるなど,最先端研究者だと,JAXも触る必要性が少しずつ出てきている.
1.2.1 JAXについて:改良版AutogradとXLA(機械学習特化のJITコンパイラー)
JAXは,深層学習のハイパフォーマンス化を特に意識した,Googleが中心で開発されている機械学習向けの数値計算ライブラリである.JAX は「改良版Autograd」と,「XLA(機械学習向けのコンパイル最適化)」の2つから構成さており,シンプルな短いコードで済みながら,高性能な機械学習向けソフトウェアを開発できる.
Google独自の改良版autogradを使用しているので,先行のPyTorchなどに含まれる(通常版の)autogradよりも,自動微分の性能が良くなっており,深層学習に力を発揮する.
また,XLAでは,numpyを含むコードを,Just-in-time コンパイルしてくれるので「関数単位の高性能化・省メモリ化」をXLAにお任せできる.具体的には,XLAは,各関数を,簡単な記述を追加するだけで「自動並列化・自動ベクトル化・JITコンパイル」を行うことができ,ソフトウェアのパフォマーンスを向上させてくれる(ドキュメントのQuick Startが,初めての人には参考になる.)
よって,JAXにおいても,XLAを使用して,コーディングの負担が少ないまま,機械学習ソフトウェアの高性能化が達成できる(※ XLAと類似するものに,Anacondaの「numba」によるnumpy計算のJITコンパイルがある.numbaでも@jitと各関数に書くだけでJITコンパイルが実行できる)
JAXについてより詳しくは,公式githubの「What is JAX?」の節を参照のこと.
1.2.2 JAXベースのシンプルなDLライブラリ:HaikuとFlax
2021以降,研究者中心に,JAXベースのDeep Learningライブラリを使ったコードをよく見るようなり,それまでの2大定番(Pytorch Tensfor flow)と合わせて,第3の定番と化しつつある.
具体的は,以下の,Haiku,FlaxやTraxなどが,研究発表でも公開コードで見る機会が増えてきた(特にFlaxが普及しつつある):
- Haiku(Google DeepMind):
- DeepMindの Sonnet(Tensorflow2ベース)を,JAXベースのライブラリにしてパワーアップさせたものが,このHaikuである.
- よって,DeepMindチームの先端的な基礎研究成果は,最近だとこのHaikuでコードが公開されている.
- Flax (Google):
- Haikuとは別に,マウンテンビューのGoogle本体やResearchチームによって作られているのがFlaxである.
- 例えばViT(Vision Transformer)の公式コードは,JAX/Flaxである.つまり,ViTやMLP-Mixerなどの公式コードを研究で扱うには,まずFlaxなのである(とはいえ,その後,Pytorch/visionやTensforflow, huggingface版も用意されている)
- 一方で,DETRは,facebookから提案されたのでpytorchであるなど,google以外のTransformer系の先端研究は,まだJAXベースではない.あくまで,Google界隈でのJAX(+Flax)の使用が増えている段階である(※ 2022年3月執筆).
JAXベースのこれらのライブラリの方が,シンプルで部品的(modular)なコードで,高速かつスケールする深層学習を行える.Haikuは特にそうだが,Tensorflowベースなので,元々Tensorflowだったものを簡単にHaikuのコードへ移行できる(= Google同士なので,そうなるように設計されている).
HaikuもFlaxも,両者は「コード行数を増やさず,短いコードで機械学習の高速化・スケール化ができる」という目的のJAXが元になっている上に,以前から存在するPytorch lightningなどの「軽量でシンプル化して,なおかつ部品的に書ける」ライブラリの理念もアップデートしている.
HaikuとFlaxは,オブジェクト名こそ結構違うが(※ 全結合層の名前は,HaikuはDense()で,FlaxはLinear()),共にJAX基盤で,なおかつ達成したことも同じなので,コードの書きかたも非常に似ているし,短いコードで済む具合も似ている.FlaxではテンソルもNHWCの順序でデータを格納するようになっており,数式的な配置でない順序が多かったライブラリよりも直感的である(PyTorchはNCHWの順序).(※ numpyよりjuliaの方が数式どおりに,線形代数をプログラミングできて直感的であるのと同じ話である.現代的な数値計算ライブラリ・フレームワークほど,数学の式に寄せている).
Flaxは関数型言語の設計:おかげで柔軟性が高いが,ハードルもやや高い
Flaxは(純粋)関数型言語の方針で設計されており,おかげでかなり柔軟性の高いコーディングが可能である.一方で,従来から人気のPyTorchやKerasなどは「オブジェクト指向型」のフレームワークであった.よって,関数型言語に慣れていないまま,いきなりFlax・JAXへ移行するには,関数型プログラミングが苦手な人だと,慣れや勉強が必要な側面もある (Flaxのドキュメントのhow do I …?のページに「flaxで…を実装したいときにはどうすれば良いのか」の解説ページリンクが,一覧化されてあり参考になる).
もちろん,部品(既存の層)を組み合わせて済む程度のネットワークをつくるだけなら,書きかたの流儀が違ってもたいして難しくない.しかし,ローレベルの処理から,Flax/JAXで書こうとすると,関数型コーディングに慣れていないと難しい場面も出てくる.
また,従来「PyTorch → TensorFlow」間でのコードだと,ポーティングは比較的楽にできるのに対して,「PyTorchやHuggingFaceのコード → Flax/JAXのコード」のポーティングは,指向性が異なるので簡単ではなく,当然Flaxで用意されている各部品 (層やOptimizer)への固有の知識も必要となる.(FlaxドキュメントのConvert PyTorch Models to Flaxが参考になる).
このように,関数型設計でハードルが高い面もあるので,管理人としても,FlaxへのJAX系フレームワークへの移行を全員にお薦めするわけではない.「色々な研究で見かけるようになってきたので,存在は知っておくとよい段階になった」ので,とりあえず概要だけ書いた,としておきたい.
なぜGoogle内で2種類つくられている?
「コードが似ているなら,なぜ同じGoogle社内で,HaikuとFlaxを別々に作る意味があるの? 」という疑問もでるのが当然である.DeepMindは少数精鋭部隊であり,本社と同じマウンテンビュー所属の会社であるものの,Andrew Zisserman のVGGラボ(英 Oxford大)も提携相手なのでしてやりとりしている.よって,DeepMindは,外から見るとGoogle内でも独立部隊感が強い.そういう意味で,DeepMindのHaikuと,Google 全体のFlaskが2個あるは仕方ない.ただし「Brainグループも,Flaxの開発には関わってる」とgithubレポジトリには書いているので,結局実際の社内の関係は我々にはよくわかならい.
以上,2者の社内位置づけを考えると,Google本体が作る Flaxは,エンジニア・デプロイ志向がありそうで,Haikuは「(DeepMind的な)基礎研究のモデルや問題を提案するリサーチャー」やそうでなくても応用的なリサーチャー全般が使いやすいと,とりあえず考えておいたら良いとは思う.とはいいながら,両者のドキュメントを眺めていただくとわかるが,あまり,2者のそうした違いは,ライブラリ設計からは読み取れない.
いずれにせよ,自社のTPUに特化したJAX系ライブラリを深層学習向けにも構築することで,自社クラウドサービスや,Colaboratoryの利用を促進させることが,彼らのミッションではあると思う.
1.3 その他の汎用DLライブラリ
2. ニューラルネットのアーキテクチャ図の作成
- NN-SVG: 単純な形式のCNNアーキテクチャ図をブラウザ上で作成できる.FCN, LeNet, AlexNetの3種から選べる.
- PlotNeuralNet:LatexコードやPythonで,論文用の精緻なCNNアーキテクチャの図を作成できる.UbuntuもしくはWindows(Cygwin環境)で使える(※著者がMacで使いたくてたまらないソフトウェア).
- conv_arithmetic:畳み込み演算のアニメーション,斜め上カメラ位置から見た図を作れる.
- Tensorspace.js:TensorFlow.js, Three.js and Tween.js. を用いた,Keras的なAPIでコーディングしたニューラルネットワークアーキテクチャを,ブラウザ上で3次元可視化できる.
3. コンピュータビジョン向けのライブラリ
3.1 多種タスクのコードベース/Example集
- コンピュータビジョン・ディープラーニング全般のライブラリ
- OpenMMLab:SenseTimeやその創立元である香港中文大学(CUHK)の研究者・学生が中心として開発されている,PyTorchベースの深層学習ライブラリ.
- MM=Multimediaであるのは,SenseTimeの創業を行ったXiaoou Tangの研究室MMLabが,コンピュータビジョンだけでなく,画像映像に対するマルチメディア系の研究(画像検索・映像検索・映像要約など)も含めて研究していて広義で「マルチメディアラボ」という名前の研究室であるため.
- OpenMMLabは,(主にコンピュータビジョンの)タスクごとに,細かくサブフレームワークに分離されている(代表的ものは3.2節で紹介).特に物体検出向けの MMDetection が有名だったが,2020年以降は,他のタスクの MMSegmentation や MMAction2 なども充実してきている印象.
- 全体像や使用法は CVPR2021のチュートリアルが参考になる.
- 一方で,Mask R-CNN の提案以降,Facebookラボが得意とするインスタンスセグメンテーションに関しては,OpenMMLab は成長できていない印象.
- GluonCV:MXNetを使用したコンピュータビジョン用のフレームワーク.
- 画像認識/物体検出/セグメンテーション/人物姿勢推定/アクション認識など,代表的な問題設定むけのDeepネットワークを,少ない行数で構築することができる.
- Model Zoo を見ると,用意されている学習済みネットワーク一覧を確認できる
- NVIDIAのDeep Learning Examples
- OpenMMLab:SenseTimeやその創立元である香港中文大学(CUHK)の研究者・学生が中心として開発されている,PyTorchベースの深層学習ライブラリ.
3.2 タスク別のフレームワーク(ツールキット)
- 物体検出 ( + インスタンスセグメンテーション・人物姿勢推定):
- Dectron2 (from Facebook Research )
- Dectorn2 Beginner’s tutorial (Google Colab上のチュートリアル)
- MMDetection (from OpenMMLab):
- Dectron2 (from Facebook Research )
- セマンティックセグメンテーション:
- MMSegmentation (from OpenMMLab)
- モバイル向け
- MediaPipe (from google) :スマホカメラ向けのリアルタイム画像認識ライブラリ.
- 人物姿勢推定・人物姿勢追跡
- OpenPose (CMUのpart affinity field)
- OpenPifPaf
- MMPose (from OpenMMLab)
- Media PipeのBlazePose :スマホ向けのリアルタイム人物姿勢追跡.
- アクション認識
- MMAction2 (from OpenMMLab)
- PytorchVideo
- テキスト検出(自然画像からのOCR)
- Vision Transformer(ViT)系:
4. NLP 向けのライブラリ
- AllenNLP:Allen Instituteが提供するDeep Learning手法ベースの自然言語処理ライブラリ.
- Huggingface🤗のTranformers:Transformersを仕様した各種の有名モデルを提供する,自然言語処理を中心としたタスク向けのフレームワーク.PyTorch(主), Tensforflow, JAXに対応.
- Transformerを用いた最新のNLPモデルを,誰でも簡単に使うことができる.BERTやGPT3などの「事前学習モデル」がすぐ使えるのが,他のフレームワークと比べると便利.
- supported framework に,モデルごとにどのライブラリで対応されている.主対応先のPyTorchでは,基本的に全モデルがサポートされている.また,Tensorflow, JAX 向けにも,サポートされているモデルが増えてきている.
- Huggingface Transformerの初心者は,公式サイトが提供するコース (Pytorch/Tensorflow)をまず参照のこと.登場の経緯については,TechBlitzでの創業者インタビュー が参考になる
- (2022年3月追記) NLP以外でもTransformerモデルが多く提供されている.
- 初期はNLPモデル中心であったが,最近はViTやDeiT,DETR,SegFormerなどのコンピュータビジョン向けTransformerモデルも提供されるように.
- Wav2Vec2,Speech2TextなどのTTSやSSTなど音声-言語間変換のTransformer系モデルも提供.
- (2022年5月追記) 日本語BERT の例:
- HuggingFace Transformersに,「日本語で事前学習したBERT」が集結しつつある.
- 例1 ) 乾研の BERT
- 例2 ) 黒橋研 BERT-japanese(ただし,古くなっているので以下のROBERTa-japanese使用を推奨している)
- 例3 )河原研 ROBERTa-baseとROBERTa-LARGE
- spaCy: 実応用向けの,Pythonによる総合NLPツールキット.
- 最先端の研究者やその成果を使いたい人はHuggingfaceを使うようになってきているが,spaCyは様々なタスクの学習済みモデルが収録されており使い勝手が良い
- v2.3 から日本語モデルも正式に追加され,日本語でのNLPにも使える.更に, v3.0からTransformerモデルも導入された.
- また,64以上の言語向けで各種NLPモデルが収録されており,Pytorch/Tensorflowで独自モデルも組める.
- Stanza (github) : Stanford NLP Groupによる,PyTorchベースの自然言語処理ライブラリ.
- 66種類の言語向けに「多言語対応」している
- Stanford としては,NLTKの後継「多数言語向け」ライブラリである.
- spacyでのラッパーを使うと,spacyからもStanzaの各機能を使える.
5. その他 ライブラリ(用途別)
- Graph Neural Networks
- 学習高速化
- 分散学習の簡易化•高速化