A Day in the Life

日本語 Reranker 作成のテクニカルレポート

本記事は、日本語の reranker (CrossEncoder) モデルを作成における技術レポートである。reranker とは何か、といった内容は別記事 日本語最高性能のRerankerをリリース / そもそも Reranker とは? を参照のこと。

なお今回作ったモデル一覧は以下。

モデル名 layers hidden_size
hotchpotch/japanese-reranker-cross-encoder-xsmall-v1 6 384
hotchpotch/japanese-reranker-cross-encoder-small-v1 12 384
hotchpotch/japanese-reranker-cross-encoder-base-v1 12 768
hotchpotch/japanese-reranker-cross-encoder-large-v1 24 1024
hotchpotch/japanese-bge-reranker-v2-m3-v1 24 1024

CrossEncoder の学習方法

CrossEncoder は、単純な回帰タスクである。query text[SEP]passage text といったSEPトークン等で区切ったテキストを、正例は1.0 負例は0.0としてラベル付けし学習させる。具体的な学習コード例としては、SentenceTransformers の CrossEncoder 学習サンプルが分かりやすい。

また、複数の負例(ハードネガティブ)を正例と同一バッチで学習させることで性能が大きく向上する。この学習方法については、FlagEmbedding の reranker trainerが参考になる。

学習用データセット

学習には、質問と正例・負例のデータセットが必要である。1件につき、positive(正例)1個とhard-negative(負例)15個を1セットとし、1グループ=16個として学習に用いた。以下のデータセットを利用した。

  • JQaRA: dev/unused から 7,270件
  • JSQuAD:
    • train から 62,859件
    • hard-negative マイニング用に wikipedia の文章から追加
  • miracl: train の日本語データから 6,984件
  • mmarco: train の日本語データからフィルターした 346,413件
  • mr_tydi:
    • train の日本語データから 3,697件
    • なお miracl の日本語データには、このmr_tydiのデータと重複したデータが多く含まれる
  • wikipedia リード文:
    • wikipedia のタイトルと、冒頭のリード文をペアとした 40,130件
    • hard-negative マイニングでは、同様に wikipedia のリード文のみを対象にマイニング

評価用データセット

モデルの評価には、以下のデータセットを用いた。

  • JQaRA:
    • test 2000件
    • 評価指標は JQaRA での評価方法として定義されているNDCG@10
  • JSQuAD:
    • validation 4442件
    • wikipediaからhard-negativeマイニングで negatives 19件追加し、合計20件からのMAP@10で評価
  • miracl:
    • dev から negatives が9件以上のデータでフィルターした、704件
    • positive1件、negatives9件の合計10件としてMAP@10で評価
    • なお miracl に日本語データでは devtrain で一部データが被っており、train を学習すればするほど dev の評価が高くなりやすい
  • JaCWIR:
    • eval 5000件
    • 評価指標は JaCWIR Reranker 評価方法として定義されているMAP@10

ハードネガティブマイニング

ハードネガティブとは、モデルが正例と誤判断しやすいが実際は負例であるデータを指す。これらを積極的に「マイニング」することで、学習データの多様性と難易度を高め、モデルの精度向上が期待される。

本モデルでは、BM25と複数のSentenceTransformerモデルを用いてハードネガティブをマイニングした。Semantic Textual Similarity(StS)タスクにより、正例に意味的に類似するが実際は負例である文章を抽出した。類似度の高いデータからランダムにサンプリングする方法を採用した。

学習元のpre-trainモデル

以下のpre-trainモデルを学習のpre-trainモデルとして利用した。なお BAAI/bge-reranker-v2-m3 については全件学習させると汎化性能が低下したため、mmarco, JSQuAD, wikipedia リード文を各1万件にランダムサンプリング(他のデータセットは全件)したデータで学習させた。

過学習への対応

CrossEncoderの学習を進める中で、ハードネガティブにwikipediaの文章を使用しているため、wikipediaデータを利用した関連のタスクの評価(JQaRA, JSQuAD, miracl japaneseなど)には最適化されるが、wikipedia以外のドメインでの汎化性能が学習すればするほど低下することが判明した。そこで、学習データに含まれないドメイン外のデータセットであるJaCWIRを作成し、バランスをとりながら学習・評価を行った。

結果として、1 epoch以上の学習では過学習が発生したため、学習は1 epochのみに制限している。

学習パラメータ

主のモデルの学習には、主に以下のパラメータを使用した。

  • batch_size: 512 (gradient_accumulation)
    • 16個が1グループなので、pos,neg合わせて 512 * 16 = 8192 を1バッチで学習
  • warmup_ratio: 0.25 (全体の25%をwarmupに使用)
  • スケジューラ: cosine
  • オプティマイザ: paged_adamw_32bit
  • learning_rate:
    • xsmall = 2e-04
    • small = 5e-04
    • base = 8e-05
    • large = 3e-05
  • loss
    • CrossEntropy

largeモデルを教師モデルとして使用

xsmall, smallの学習では、japanese-reranker-cross-encoder-large-v1japanese-bge-reranker-v2-m3-v1の推論出力を教師ラベルとして追加利用した。教師モデルの出力は推論値(例: pos=0.98, negs=[0.02, 0.07, ...])となるため、01だけでなく回帰タスクの連続値としての利用が可能である。教師モデルの出力データを用いることで、大幅ではないが若干のスコア向上が観測された。なお、この学習にはlossはMSEを用いた。

mixモデルの作成

学習データセットやスコアパラメータ、シード値を変更することで、多様な学習結果が得られる。これらの個別に学習したモデルを単純に線形結合することで、多様性を持たせパフォーマンスを向上させることができる。今回、複数の学習済みモデルを結合することでスコアの向上を確認した。なおモデル合成のツールにはLM_Cocktailを利用した。

注意事項としては、合成後のモデルは出力値の標準偏差が小さくなるため、量子化時等になんらかの性能劣化が発生する可能があるかもしれない。

評価結果

作成したCrossEncoderモデルの評価結果は以下の通りである。BAAI/bge-reranker-v2-m3は元々のマルチリンガル言語に対しての汎化性能が高く初めから日本語に対して高性能で、モデルサイズが問題にならなければ、少量のサンプル(数千件程度)でも微調整可能なため reranker 学習元の微調整モデルとしては最適と考えられる。

なお、この評価データセットのスコアはそのデータセットが公開しているtrainデータ等で学習することでスコアが高く出る傾向にある。今回作ったモデルは、JaCWIR 以外はtrain等のデータで傾向を学習しているため、その点も評価スコアを見る際には留意すると良いであろう。

Model Name JQaRA JaCWIR MIRACL JSQuAD
japanese-reranker-cross-encoder-xsmall-v1 0.6136 0.9376 0.7411 0.9602
japanese-reranker-cross-encoder-small-v1 0.6247 0.939 0.7776 0.9604
japanese-reranker-cross-encoder-base-v1 0.6711 0.9337 0.818 0.9708
japanese-reranker-cross-encoder-large-v1 0.7099 0.9364 0.8406 0.9773
japanese-bge-reranker-v2-m3-v1 0.6918 0.9372 0.8423 0.9624
bge-reranker-v2-m3 0.673 0.9343 0.8374 0.9599
bge-reranker-large 0.4718 0.7332 0.7666 0.7081
bge-reranker-base 0.2445 0.4905 0.6792 0.5757
cross-encoder-mmarco-mMiniLMv2-L12-H384-v1 0.5588 0.9211 0.7158 0.932
shioriha-large-reranker 0.5775 0.8458 0.8084 0.9262
bge-m3+all 0.576 0.904 0.7926 0.9226
bge-m3+dense 0.539 0.8642 0.7753 0.8815
bge-m3+colbert 0.5656 0.9064 0.7902 0.9297
bge-m3+sparse 0.5088 0.8944 0.6941 0.9184
JaColBERTv2 0.5847 0.9185 0.6861 0.9247
multilingual-e5-large 0.554 0.8759 0.7722 0.8892
multilingual-e5-small 0.4917 0.869 0.7025 0.8565
bm25 0.458 0.8408 0.4387 0.9002

なおこの文章は、私が書いたメモと指示を元に、Claude 3 Opus によって生成された文章を微調整したものである。

記事の一覧 >

関連するかもエントリー

日本語最高性能のRerankerをリリース / そもそも Reranker とは?
日本語に特化した形で学習されたRerankerがほとんど無かったので、日本語を適切に学習させた Reranker ファミリーを作りました。小さいモデルから大きなモデルまで揃っています。評価性能は以下の通りで、現在(2024年4月頭)に公開されているRerank日本語タスクにおいて...
日本語に特化した形で学習されたRerankerがほとんど無かったので、日本語を適切に学習させた Reranker ファミリーを作りました。小...
Q&A + RAG に特化したLLMをSFTで学習させ4bit量子化モデルを作り、GPT3.5以上の性能を7Bモデルで達成する
この記事は、LLM Advent Calendar 2023の12月15日の記事である。先日、RAG用途に使える、Wikipedia 日本語の embeddings とベクトル検索用の faiss index を作ったのだけど、質問に対しての回答の抽出に GPT 3.5 / 4 ...
この記事は、LLM Advent Calendar 2023の12月15日の記事である。先日、RAG用途に使える、Wikipedia 日本語...
Kaggle コンペ Feedback Prize - English Language Learning でチーム参加15位金メダル取得で、Kaggle Master へ
Kaggle のコンペティション、Feedback Prize - English Language Learningが終わり、約2650チーム中15位で金メダル取得となった。これで合計金メダル2つ、銀メダル1つを取得し、Kaggle Competitions Master の条...
Kaggle のコンペティション、Feedback Prize - English Language Learningが終わり、約2650チ...