SPLADE モデルの作り方・日本語SPLADEテクニカルレポート
近年、大規模言語モデル(LLM)の台頭により、情報検索の重要性が増している。特に、Retrieval-Augmented Generation(RAG)などの応用分野では、効率的で高精度な検索システムが求められている。
ニューラルネットワークを用いた検索モデルの分野では、密ベクトルモデル(dense retriever)が主流となっており、multilingual-e5 や bge-m3 のようなマルチリンガル対応の高性能モデルも登場している。一方で、SPLADE(Sparse Lexical and Expansion Model with Contextualized Embeddings)に代表されるスパース検索モデルも、英語圏において高い性能を示している。
しかし、SPLADE は単語の特徴量に大きく依存し、そのトークン化がモデルのトークナイザに左右されるため、マルチリンガル対応版が存在していなかった。マルチリンガルモデルのトークナイザでは多くの言語で1文字単位の分割が行われ、意味のある単語単位でのトークン化が困難であったためである。そこで日本語に特化したSPLADEモデルを開発し、その評価を行った。
さらに、元のSPLADE実装(naver/splade)がCC-BY-NCライセンスで提供されており商用利用に制限があることから、論文を基にTrainerを実装し、MITライセンスのオープンソースソフトウェアとして公開した。
本レポートでは、日本語SPLADEモデルの実装詳細、評価実験の結果、および今後の展望について報告する。
SPLADEのアルゴリズム
SPLADE は、情報検索においてスパースな文書およびクエリ表現を学習するためのモデルである。本節では、SPLADE がどのように学習されるか、そのアルゴリズムについて記述する。
単語重要度の計算と単語ごとの出力トークンの利用
SPLADE は、Masked Language Modeling(MLM)などで事前学習されたモデルの各単語ごとの出力トークンを利用し、文脈に応じた単語の重要度を計算する。具体的には、BERT のような事前学習モデルの語彙空間を活用し、入力シーケンスの各位置で得られた単語のスコアから最大値を選択する max pooling を用いる。また、対数飽和関数を適用することで、極端な値を抑制しつつ重要な特徴を強調することが可能である。これらの手法により、顕著な特徴を捉えたスパースで効率的な文書およびクエリの表現を生成する。
なお、これらの操作は SPLADE Max と呼ばれるもので、Python での実装を以下に示す。
def splade_max_pooling(logits, attention_mask):
# Step 1: 対数飽和関数の適用 (log(1 + x))
# - torch.relu() で負の値を0にする
# - torch.log(1 + x)で値を対数スケールに変換し、大きな値を抑制
relu_log = torch.log(1 + torch.relu(logits))
# Step 2: attention_maskを使って、padding された位置のスコアを0にマスク
# unsqueeze(-1)で次元を合わせる(batch_size, seq_len, 1)
weighted_log = relu_log * attention_mask.unsqueeze(-1)
# Step 3: max pooling の適用
# torch.max()で系列長方向(dim=1)の最大値を取得
# 各語彙に対する最も重要なスコアを選択する
max_val, _ = torch.max(weighted_log, dim=1)
return max_val
単語重要度を用いたドキュメントとクエリの予測
SPLADE Max を通じて得られた単語重要度を活用し、ドキュメントとクエリの関連度を予測する。関連度は主に内積を利用する。この予測結果と元の学習データとの間の差異を、損失関数として定義する。
この損失関数には、モデルが予測した語彙分布と実際の語彙分布との間の差異を測定するために、KLダイバージェンス損失、MarginMSE損失、クロスエントロピー損失等を用いる。これらの損失関数は単体で用いても、複数を組み合わせても良い。SPLADE-v3では、KLダイバージェンス損失とMargineMSE損失を組み合わせて使っている。
スパース性の導入と正則化
出力される単語重要度にスパース性を持たせるため、正則化手法を損失関数に組み込む。具体的には、以下のアルゴリズムが使用される。
-
L1正則化:モデルのパラメータの絶対値の総和を最小化することで、多くのパラメータをゼロに近づける。この手法により、重要でない単語の影響を排除し、スパースな表現を促進する。
-
FLOPs正則化:高次元でスパースな表現学習において、非ゼロ要素を次元間で均一に分散させることで行列演算の計算量(FLOPs)を二次的に削減する正則化手法。(Minimizing FLOPs to Learn Efficient Sparse Representations)
なお、クエリとドキュメントでは異なる損失関数や正則化係数を適用することが可能である。また、学習の初期段階から強い正則化を適用すると、重要度予測に悪影響を及ぼす可能性がある。そのため、適用まで緩やかなウォームアップ期間を設けて、正則化損失の重みを徐々に高めていく手法も取り入れている。
モデルの学習と関連度の計算
これらの手法を組み込んで学習を行うことで、スパース性を促進させつつ、クエリとドキュメントの関連度を高めるモデルを構築できる。SPLADE は、スパースな表現とニューラルネットワークの文脈を含めた語彙情報を組み合わせることで、高性能な情報検索が実現できる。
日本語モデルでの学習手法
データセットの準備
最終的に学習したモデル japanese-splade-base-v1 の学習用データセットとして、日本語の様々な質問文と回答、ハードネガティブを集めたhpprc/embのうち、auto-wiki-qa、mmarco、jsquad、jaquad、auto-wiki-qa-nemotron、quiz-works、quiz-no-mori、miracl、jqara、mr-tydi、baobab-wiki-retrieval、mkqa を利用した。また、hpprc/embのデータに対して日本語高性能なクロスエンコーダーを用いたリランカー(BAAI/bge-reranker-v2-m3、cl-nagoya/ruri-reranker-large)を使用し、スコア付けを行ったデータセットも作成した(hotchpotch/hpprc_emb-scores)。さらに、英語データセットとしてMS MARCOと、そのデータに BAAI/bge-reranker-v2-m3 でスコア付けしたデータを利用した。
データのフィルタリングにおいては、各リランカーの平均スコアを用いて、正例に対してはスコアが0.7以上、負例に対しては0.3以下のデータを選別した。これは、質問に対して適切なスコアでないとリランカーが判断した文章を除外するためである。
データセットの割合が少ないものについては、1エポックあたりの学習量を増加させた。これは、そのデータセットの特性をモデルが忘れないようにするためである。
また、mmacro(日本語)のみを学習させるデータセットとして、mmacroとBAAI/bge-reranker-v2-m3でスコア付けしたデータセットhotchpotch/mmarco-hard-negatives-reranker-scoreを作成し、利用した。このデータもリランカーの平均スコアを用いて、同様に正例に対してはスコアが0.7以上、負例に対しては0.3以下のデータを選別した。
学習の設定とハイパーパラメータ
学習における損失関数として、単純なクロスエントロピー損失を採用した。これは、高性能なリランカーから得られたスコアをモデルが学習できるようにするためである。他にもKLダイバージェンス損失やMarginMSE損失を試したが、クロスエントロピー損失が最良の結果を示した。
スパース性を促進する正則化項には、L1正則化を使用した。これは、FLOPs損失と比較した際、日本語においてL1正則化の方がスパース性の促進効果が高かったためである。
ハイパーパラメータとして、学習率(Learning Rate, LR)は一般的な110Mパラメータのモデルで用いられる5.0e-2
を設定した。学習率のスケジューラにはコサインスケジューラを採用し、全体の10%をウォームアップ期間として設定した。
また1つのバッチでは正例1つ・負例7つ、合計8つのデータを含めている。バッチサイズは、japanese-splade-base-v1 が 32、japanese-splade-base-v1-mmarco-only が 128である。これは、mmacroのみの場合はクエリと文章のスパース性の収束が大きなバッチサイズでも早く、多様なデータセットを学習しているjapanese-splade-base-v1ではバッチサイズが大きいとスパース性の収束が遅くなるため、小さいバッチサイズの方が適していたためである。なお、学習時間やリソースに余裕があるなら、japanese-splade-base-v1も大きいバッチサイズの方が良い結果になる可能性がある。
その他、詳細なパラメータは、実際の学習に使った設定ファイルを参考にされたし。
ノイズトークンの除去
日本語での学習において、、
。
「
:
などの句読点や記号のトークンがノイズとして顕著に特徴量に現れることが確認された。これらのトークンがSPLADE Maxの出力に残存する場合、ペナルティとしてそのトークンのスコアを損失関数に追加している。また、これらのトークンはfugashiとunidic-liteを用いて、記号的な単語と判定できるものを抽出した。
これらをノイズトークンとして扱い、損失に組み込むことで、学習済みモデルの出力においてこれらのノイズトークンはほぼ出力されなくなった。また、学習の安定性が向上し、収束速度の速さも観測された。
学習元モデルの選択
今回、学習元となるモデルには、MLM(Masked Language Modeling)による事前学習で獲得した語彙の意味的特徴量を出力層に持つtohoku-nlp/bert-base-japanese-v3を利用した。このモデルは日本語BERTアーキテクチャをベースとしている。
学習
これらを基に、japanese-splade-base-v1とjapanese-splade-base-v1-mmarco-onlyモデルをファインチューニングし作成した。学習にかかった時間はGPU RTX4090環境で、japanese-splade-base-v1が約33時間、japanese-splade-base-v1-mmarco-onlyが約24時間である。
また、japanese-splade-base-v1においてはデータセットサイズが大きいため2エポック、japanese-splade-base-v1-mmarco-onlyにおいてはデータセットはmmacroのみとデータセットが小さいため12エポック学習した。なお、japanese-splade-base-v1の学習エポックを増やすと過学習になるためか、トレイン損失値は下がるが、評価時の検索タスクにおいて性能低下が確認された。
なお、学習したモデルは HuggingFace で公開している。
- https://huggingface.co/hotchpotch/japanese-splade-base-v1
- https://huggingface.co/hotchpotch/japanese-splade-base-v1-mmarco-only
評価結果
JMTEB retrieval タスクでの評価結果
JMTEBでの評価結果は以下の通りである。なお、実際の評価にはスパースベクトルを評価できるように変更したfork版を利用している。
model_name | Avg. | jagovfaqs | jaqket | mrtydi | nlp_journal abs_intro |
nlp_journal title_abs |
nlp_journal title_intro |
---|---|---|---|---|---|---|---|
japanese-splade-base-v1 | 0.7465 | 0.6499 | 0.6992 | 0.4365 | 0.8967 | 0.9766 | 0.8203 |
japanese-splade-base-v1-mmarco-only | 0.7313 | 0.6513 | 0.6518 | 0.4467 | 0.8893 | 0.9736 | 0.7751 |
text-embedding-3-large | 0.7448 | 0.7241 | 0.4821 | 0.3488 | 0.9933 | 0.9655 | 0.9547 |
GLuCoSE-base-ja-v2 | 0.7336 | 0.6979 | 0.6729 | 0.4186 | 0.9029 | 0.9511 | 0.7580 |
multilingual-e5-large | 0.7098 | 0.7030 | 0.5878 | 0.4363 | 0.8600 | 0.9470 | 0.7248 |
multilingual-e5-small | 0.6727 | 0.6411 | 0.4997 | 0.3605 | 0.8521 | 0.9526 | 0.7299 |
ruri-large | 0.7302 | 0.7668 | 0.6174 | 0.3803 | 0.8712 | 0.9658 | 0.7797 |
結果の平均としては、japanese-splade-base-v1 が mrtydi や JAQKET のドメインタスクを学習(JMTEB の評価で使うテストデータではない)しているが、japanese-splade-base-v1 が最良の結果となった。また、japanese-splade-base-v1-mmarco-only は mmacro データセットしか学習させていないが、mrtydiでは最良の結果となり、他のタスクも他のモデルと十分競争力がある結果となった。
jagovfaqs の結果は、SPLADE モデルが他のモデルに比べて軒並み悪い。これは jagovfaqs のクエリの内容が「FAQ」であり、要約・文脈類似タスクに似た問題が多く含まれることが考えられる。他のモデルは文章の意味的類似度を学習しており、japanese-splade-base-v1 は学習していない。また、スコアが高い日本語モデルのruri-largeやGLuCoSE-base-ja-v2では、マルチリンガルFAQ(Frequently Asked Questions) & CQA(Community Question Answering)データセットのの MQAの日本語データを学習していることも、スコア向上に寄与している可能性がある。
jaqket の結果は、「クイズ形式」の質問が多く含まれる。「XXXといえばYYYですが、ZZZといえば何でしょう?」のような日本語クイズ独特の言い回しを含んでおり、それらの表現を学習しているモデルが高スコアになる。また、正解の文章内部に正解の単語を必ず含むため、単語特徴量に強い SPLADE が高スコアにつながると考えられる。
mrtydi の結果は、mrtydi のドメインを学習しているはずのjapanese-splade-base-v1が、ドメインを学習していないjapanese-splade-base-v1-mmarco-onlyよりも悪いという、直感に反する結果となった。これについては十分な考察ができていない。
nlp_journal の三つのタスクにおいては、title_abs においては、SPLADEモデルが軒並み高性能だが、abs_intro、title_introにおいては text-embedding-3-large が圧倒的に高性能である。これはtitle_absの文章の平均長が442で、abs_intro、title_introは2052のためである。text-embedding-3-large 以外はモデルのトークンの最大長が全て512であり、text-embedding-3-largeは8191である。そのため、text-embedding-3-large 以外のモデルはabs_intro、title_introの文章全体を処理することができず、文章の冒頭一部のみでの評価になるため、長いトークン長を理解可能なモデルが高いスコアとなる。
reranking タスクでの評価結果
reranking タスクの評価には JQaRA、JaCWIR を用いた。
model_name | JaCWIR map@10 | JaCWIR HR@10 | JQaRA ndcg@10 | JQaRA mrr@10 |
---|---|---|---|---|
japanese-splade-base-v1 | 0.9122 | 0.9854 | 0.6441 | 0.8616 |
japanese-splade-base-v1-mmarco-only | 0.8953 | 0.9746 | 0.5740 | 0.8176 |
text-embedding-3-small | 0.8168 | 0.9506 | 0.3881 | 0.6107 |
GLuCoSE-base-ja-v2 | 0.8567 | 0.9676 | 0.6060 | 0.8359 |
bge-m3+dense | 0.8642 | 0.9684 | 0.5390 | 0.7854 |
multilingual-e5-large | 0.8759 | 0.9726 | 0.5540 | 0.7988 |
multilingual-e5-small | 0.8690 | 0.9700 | 0.4917 | 0.7291 |
ruri-large | 0.8291 | 0.9594 | 0.6287 | 0.8418 |
結果としては、JQaRA のドメインを学習しているとはいえ、japanese-splade-base-v1 がどれも最良の結果となった。
英語タスクでの評価
japanese-splade-base-v1は、MS MARCOの英語データセットも学習データセットに含めた。そのため、naver/spladeで公開されている評価スクリプトを用い、MS MARCO(dev)で評価した。
model_name | MRR@10 (MS MARCO dev) |
---|---|
japanese-splade-base-v1 | 0.047 |
japanese-splade-base-v1-mmarco-only | 0.036 |
naver/splade_v2_max | 0.340 |
結果として、英語データを学習していないjapanese-splade-base-v1-mmarco-onlyよりも、わずかながらスコア向上が見られるが、英語のみを学習している naver/splade_v2_max と比べると著しくスコアが低く、英語における検索性能はほとんどないと言える。
スパース性の評価
スパース性の評価では、非ゼロ要素の数(L0ノルム)を用いて、各モデルのクエリおよび文書のスパース性を測定した。以下に、JMTEBのretrieveタスク(Top-1000)における japanese-splade-base-v1 および japanese-splade-base-v1-mmarco-only モデルのクエリおよび文書のスパース性の結果を示す。
なお、この結果は JMTEB_L0.py で計測した。
JMTEB tasks | v1 | v1-mmarco-only |
---|---|---|
jagovfaqs_22k-query | 27.9 | 43.4 |
jaqket-query | 23.3 | 38.9 |
mrtydi-query | 13.8 | 20.5 |
nlp_journal_abs_intro-query | 75.3 | 127.2 |
nlp_journal_title_abs-query | 19 | 26.4 |
nlp_journal_title_intro-query | 19 | 26.4 |
jagovfaqs_22k-docs | 73.2 | 97.9 |
jaqket-docs | 146.2 | 231.8 |
mrtydi-docs | 89.3 | 100.4 |
nlp_journal_abs_intro-docs | 95.7 | 182 |
nlp_journal_title_abs-docs | 75.2 | 126.9 |
nlp_journal_title_intro-docs | 95.7 | 182 |
L0ノルムの値から、v1-mmarco-onlyの方が全体的に非ゼロ要素が多く、スパース性が低いことが示されている。クエリと文書のスパース性の度合いは、検索システムのパフォーマンスに対する重要な要素とされるが、クエリと文書には異なる要件がある。
検索速度を考慮する場合、クエリのスパース性が高いほど効率的な検索が期待できるが、文書のスパース性もまた省メモリや省ディスクの観点で重要である。ただし、実運用環境では数百万〜数千万規模の文書が1台のマシンでもオンメモリで検索可能な場合が多いため、文書のスパース性についてはクエリほど厳格に管理する必要はないと考えられる。
一方、クエリのスパース性は検索速度に直接関係するため、できる限り高いスパース性が求められる。ただし、文書のスパース性に関しても、非ゼロ要素が少なすぎると検索性能に悪影響を及ぼす可能性があるため、適切なバランスが求められる。検索システムの性能と効率の両立を目指す上で、クエリと文書のスパース性を考慮したチューニングが重要である。
評価の考察まとめ
これらの結果から、japanese-splade-base-v1 は日本語データの検索において、最新のモデルと十分競争力があるモデルと言える。とりわけ、単語特徴量が重要と思われるタスクでは優れた性能を発揮する。クエリや文章のスパース性能も、必要十分と言えよう。
また、他のモデルは密ベクトルモデルであるが、SPLADE はスパースベクトルモデルであり、単語特徴量を重視する検索結果になるため、密ベクトルモデルのみを利用するより、異なるモデルを組み合わせることで多様性のある検索結果を得ることができる。これは、実世界で多様な検索結果を取得したいというケース、例えばLLMにさまざまな検索情報を渡すなど、で重要になるだろう。
今後の展望
一旦、japanese-splade-base-v1 を成果物として公開したが、まだ性能向上の余地は多い。SPLADEの元論文では、自己蒸留(self distillation)や複数の損失スコアの利用、SPLADEモデル自体を使ったハードネガティブサンプリングなどを行うことにより、性能向上が図られている。
また、検索タスクに適した事前学習モデルの選択・学習なども行えていない。例えば、Unsupervised Corpus Aware Language Model Pre-training for Dense Passage RetrievalやRetroMAE: Pre-Training Retrieval-oriented Language Models Via Masked Auto-Encoder等、検索タスクに適した事前学習を行うことで、性能向上の可能性がある。
他にも、FAQ系のタスクのデータセットの学習やロングコンテキストへの対応、多様なデータセット(現状ではWikipediaに偏りがち)の追加等が考えられる。
近年、Llama 3.1 をはじめとする、LLM の出力を学習に利用可能なライセンスを持つモデルが登場し、ライセンス上の問題なく検索用データセットを作成できるようになってきた。本モデルでも利用した hpprc/emb では、LLM の出力を活用した高品質なデータセットを提供している(Ruri: Japanese General Text Embeddings)。
従来、ドキュメントから情報検索に適したクエリを作成することは人手がかかり大変であったが、LLM を用いて自動的に生成することで、低コストで大量のクエリを作成できるようになった。特定のドメインを学習することで一般化性能が向上する場合が多く、情報検索モデルの学習用データセットが充実することで、さらなる性能向上が期待できる。
おわりに
本レポートでは、日本語に特化したSPLADEモデルであるjapanese-splade-base-v1を開発し、その評価を行った。評価結果から、日本語の情報検索において、既存の最新モデルと比較しても高い性能を示すことが確認できた。
今後の課題として、さらなる性能向上のための手法の検討や、検索タスクに適した事前学習モデルの選択、多様なデータセットの活用が挙げられる。
日本語SPLADEモデルとSPLADEモデル学習用Trainerの公開により、情報検索技術の発展に寄与できれば幸いである。
参考文献
- SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
- SPLADE v2: Sparse Lexical and Expansion Model for Information Retrieval
- From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective
- An Efficiency Study for SPLADE Models
- A Static Pruning Study on Sparse Neural Retrievers
- SPLADE-v3: New baselines for SPLADE
- Minimizing FLOPs to Learn Efficient Sparse Representations
- Ruri: Japanese General Text Embeddings
- JaColBERTv2.5: Optimising Multi-Vector Retrievers to Create State-of-the-Art Japanese Retrievers with Constrained Resources
- 日本語テキスト埋め込みベンチマークJMTEBの構築
- mMARCO: A Multilingual Version of the MS MARCO Passage Ranking Dataset
- Mr. TyDi: A Multi-lingual Benchmark for Dense Retrieval
- JaCWIR: Japanese Casual Web IR - 日本語情報検索評価のための小規模でカジュアルなWebタイトルと概要のデータセット
- JQaRA : Japanese Question Answering with Retrieval Augmentation - 検索拡張(RAG)評価のための日本語 Q&A データセット
- JAQKET: クイズを題材にした日本語 QA データセットの構築
- 高性能な日本語SPLADE(スパース検索)モデルを公開しました - A Day in the Life
@article{tateno2024splade,
title={SPLADE モデルの作り方・日本語SPLADEテクニカルレポート},
author={TatenoYuichi},
year={2024},
url={https://secon.dev/entry/2024/10/23/080000-japanese-splade-tech-report/}
}