日本語 Reranker 作成のテクニカルレポート
本記事は、日本語の reranker (CrossEncoder) モデルを作成における技術レポートである。reranker とは何か、といった内容は別記事 日本語最高性能のRerankerをリリース / そもそも Reranker とは? を参照のこと。
なお今回作ったモデル一覧は以下。
モデル名 | layers | hidden_size |
---|---|---|
hotchpotch/japanese-reranker-cross-encoder-xsmall-v1 | 6 | 384 |
hotchpotch/japanese-reranker-cross-encoder-small-v1 | 12 | 384 |
hotchpotch/japanese-reranker-cross-encoder-base-v1 | 12 | 768 |
hotchpotch/japanese-reranker-cross-encoder-large-v1 | 24 | 1024 |
hotchpotch/japanese-bge-reranker-v2-m3-v1 | 24 | 1024 |
CrossEncoder の学習方法
CrossEncoder は、単純な回帰タスクである。query text[SEP]passage text
といったSEPトークン等で区切ったテキストを、正例は1.0
負例は0.0
としてラベル付けし学習させる。具体的な学習コード例としては、SentenceTransformers の CrossEncoder 学習サンプルが分かりやすい。
また、複数の負例(ハードネガティブ)を正例と同一バッチで学習させることで性能が大きく向上する。この学習方法については、FlagEmbedding の reranker trainerが参考になる。
学習用データセット
学習には、質問と正例・負例のデータセットが必要である。1件につき、positive(正例)1個とhard-negative(負例)15個を1セットとし、1グループ=16個として学習に用いた。以下のデータセットを利用した。
- JQaRA:
dev
/unused
から 7,270件 - JSQuAD:
train
から 62,859件- hard-negative マイニング用に wikipedia の文章から追加
- miracl:
train
の日本語データから 6,984件 - mmarco:
train
の日本語データからフィルターした 346,413件 - mr_tydi:
train
の日本語データから 3,697件- なお miracl の日本語データには、このmr_tydiのデータと重複したデータが多く含まれる
- wikipedia リード文:
- wikipedia のタイトルと、冒頭のリード文をペアとした 40,130件
- hard-negative マイニングでは、同様に wikipedia のリード文のみを対象にマイニング
評価用データセット
モデルの評価には、以下のデータセットを用いた。
- JQaRA:
test
2000件- 評価指標は JQaRA での評価方法として定義されている
NDCG@10
- JSQuAD:
validation
4442件- wikipediaからhard-negativeマイニングで negatives 19件追加し、合計20件からの
MAP@10
で評価
- miracl:
dev
からnegatives
が9件以上のデータでフィルターした、704件positive
1件、negatives
9件の合計10件としてMAP@10
で評価- なお miracl に日本語データでは
dev
とtrain
で一部データが被っており、train
を学習すればするほどdev
の評価が高くなりやすい
- JaCWIR:
eval
5000件- 評価指標は JaCWIR Reranker 評価方法として定義されている
MAP@10
ハードネガティブマイニング
ハードネガティブとは、モデルが正例と誤判断しやすいが実際は負例であるデータを指す。これらを積極的に「マイニング」することで、学習データの多様性と難易度を高め、モデルの精度向上が期待される。
本モデルでは、BM25と複数のSentenceTransformerモデルを用いてハードネガティブをマイニングした。Semantic Textual Similarity(StS)タスクにより、正例に意味的に類似するが実際は負例である文章を抽出した。類似度の高いデータからランダムにサンプリングする方法を採用した。
学習元のpre-trainモデル
以下のpre-trainモデルを学習のpre-trainモデルとして利用した。なお BAAI/bge-reranker-v2-m3
については全件学習させると汎化性能が低下したため、mmarco, JSQuAD, wikipedia リード文を各1万件にランダムサンプリング(他のデータセットは全件)したデータで学習させた。
japanese-reranker-cross-encoder-xsmall-v1
- microsoft mMiniLMv2-L6-H384
- 6 layers, 384 hidden size
japanese-reranker-cross-encoder-small-v1
- microsoft mMiniLMv2-L12-H384
- 12 layers, 384 hidden size
japanese-reranker-cross-encoder-base-v1
- cl-nagoya/sup-simcse-ja-base
- tohoku-nlp/bert-base-japanese-v3
- 二つのモデルで学習させたものの統合モデル
- 12 layers, 768 hidden size
japanese-reranker-cross-encoder-large-v1
- cl-nagoya/sup-simcse-ja-large
- tohoku-nlp/bert-large-japanese-v2
- 2つのモデルで学習させたものの統合モデル
- 24 layers, 1024 hidden size
japanese-bge-reranker-v2-m3-v1
- BAAI/bge-reranker-v2-m3
- 24 layers, 1024 hidden size
過学習への対応
CrossEncoderの学習を進める中で、ハードネガティブにwikipediaの文章を使用しているため、wikipediaデータを利用した関連のタスクの評価(JQaRA, JSQuAD, miracl japaneseなど)には最適化されるが、wikipedia以外のドメインでの汎化性能が学習すればするほど低下することが判明した。そこで、学習データに含まれないドメイン外のデータセットであるJaCWIRを作成し、バランスをとりながら学習・評価を行った。
結果として、1 epoch
以上の学習では過学習が発生したため、学習は1 epoch
のみに制限している。
学習パラメータ
主のモデルの学習には、主に以下のパラメータを使用した。
batch_size
:512
(gradient_accumulation
)- 16個が1グループなので、
pos
,neg
合わせて512 * 16 = 8192
を1バッチで学習
- 16個が1グループなので、
warmup_ratio
:0.25
(全体の25%をwarmupに使用)- スケジューラ:
cosine
- オプティマイザ:
paged_adamw_32bit
learning_rate
:xsmall
=2e-04
small
=5e-04
base
=8e-05
large
=3e-05
- loss
- CrossEntropy
largeモデルを教師モデルとして使用
xsmall
, small
の学習では、japanese-reranker-cross-encoder-large-v1
とjapanese-bge-reranker-v2-m3-v1
の推論出力を教師ラベルとして追加利用した。教師モデルの出力は推論値(例: pos
=0.98
, negs
=[0.02, 0.07, ...]
)となるため、0
と1
だけでなく回帰タスクの連続値としての利用が可能である。教師モデルの出力データを用いることで、大幅ではないが若干のスコア向上が観測された。なお、この学習にはlossはMSEを用いた。
mixモデルの作成
学習データセットやスコアパラメータ、シード値を変更することで、多様な学習結果が得られる。これらの個別に学習したモデルを単純に線形結合することで、多様性を持たせパフォーマンスを向上させることができる。今回、複数の学習済みモデルを結合することでスコアの向上を確認した。なおモデル合成のツールにはLM_Cocktailを利用した。
注意事項としては、合成後のモデルは出力値の標準偏差が小さくなるため、量子化時等になんらかの性能劣化が発生する可能があるかもしれない。
評価結果
作成したCrossEncoderモデルの評価結果は以下の通りである。BAAI/bge-reranker-v2-m3
は元々のマルチリンガル言語に対しての汎化性能が高く初めから日本語に対して高性能で、モデルサイズが問題にならなければ、少量のサンプル(数千件程度)でも微調整可能なため reranker 学習元の微調整モデルとしては最適と考えられる。
なお、この評価データセットのスコアはそのデータセットが公開しているtrainデータ等で学習することでスコアが高く出る傾向にある。今回作ったモデルは、JaCWIR 以外はtrain等のデータで傾向を学習しているため、その点も評価スコアを見る際には留意すると良いであろう。
Model Name | JQaRA | JaCWIR | MIRACL | JSQuAD |
---|---|---|---|---|
japanese-reranker-cross-encoder-xsmall-v1 | 0.6136 | 0.9376 | 0.7411 | 0.9602 |
japanese-reranker-cross-encoder-small-v1 | 0.6247 | 0.939 | 0.7776 | 0.9604 |
japanese-reranker-cross-encoder-base-v1 | 0.6711 | 0.9337 | 0.818 | 0.9708 |
japanese-reranker-cross-encoder-large-v1 | 0.7099 | 0.9364 | 0.8406 | 0.9773 |
japanese-bge-reranker-v2-m3-v1 | 0.6918 | 0.9372 | 0.8423 | 0.9624 |
bge-reranker-v2-m3 | 0.673 | 0.9343 | 0.8374 | 0.9599 |
bge-reranker-large | 0.4718 | 0.7332 | 0.7666 | 0.7081 |
bge-reranker-base | 0.2445 | 0.4905 | 0.6792 | 0.5757 |
cross-encoder-mmarco-mMiniLMv2-L12-H384-v1 | 0.5588 | 0.9211 | 0.7158 | 0.932 |
shioriha-large-reranker | 0.5775 | 0.8458 | 0.8084 | 0.9262 |
bge-m3+all | 0.576 | 0.904 | 0.7926 | 0.9226 |
bge-m3+dense | 0.539 | 0.8642 | 0.7753 | 0.8815 |
bge-m3+colbert | 0.5656 | 0.9064 | 0.7902 | 0.9297 |
bge-m3+sparse | 0.5088 | 0.8944 | 0.6941 | 0.9184 |
JaColBERTv2 | 0.5847 | 0.9185 | 0.6861 | 0.9247 |
multilingual-e5-large | 0.554 | 0.8759 | 0.7722 | 0.8892 |
multilingual-e5-small | 0.4917 | 0.869 | 0.7025 | 0.8565 |
bm25 | 0.458 | 0.8408 | 0.4387 | 0.9002 |
なおこの文章は、私が書いたメモと指示を元に、Claude 3 Opus によって生成された文章を微調整したものである。