SVMを使った類似 embeddings 検索 - kNN ではない類似検索の選択肢
LangChain v0.0.141 に SVM Retriever という実装が入った。これは embeddings(集合)から、単一 embedding と類似しているもの top-K を SVM を使って見つけるという実装で、えっどうやってるの?と追っかけてみたら、知らない知識で面白かったのでメモ記事に。
kNN vs SVM
この実装の元となった、knn_vs_svm.ipynbというnotebookがあって、冒頭を機械翻訳すると以下となる。
よくあるワークフローは、あるデータを埋め込みに基づいてインデックス化し、新しいクエリの埋め込みがあれば、k-Nearest Neighbor検索で最も類似した例を検索することです。例えば、大規模な論文コレクションをその抄録に基づいて埋め込み、興味のある新しい論文を与えると、その論文に最も類似した論文を検索することが想像できます。
私の経験では、若干の計算量に余裕があれば、kNNの代わりにSVMを使用した方が、常にうまくいきます。以下に例を挙げます:
k近傍法(KNN)ではユークリッド距離を用いて計算するが、SVMを使うというアプローチ。この SVM の利用方法が面白くてnotebookから引用すると、以下のようになっている。
# Wired: use an SVM
from sklearn import svm
# create the "Dataset"
x = np.concatenate([query[None,...], embeddings]) # x is (1001, 1536) array, with query now as the first row
y = np.zeros(1001)
y[0] = 1 # we have a single positive example, mark it as such
# train our (Exemplar) SVM
# docs: https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y) # train
# infer on whatever data you wish, e.g. the original data
similarities = clf.decision_function(x)
sorted_ix = np.argsort(-similarities)
print("top 10 results:")
for k in sorted_ix[:10]:
print(f"row {k}, similarity {similarities[k]}")
対象となる単一のembeddingにのみ1
、他はすべて0
をつけたラベルでLinearSVCで分類課題として学習する。確信スコアとしては 1~-1
が得られるので、embeddingに近いつまり1に近いものから top-K を抽出することで類似度が高いものが得られる。
単純なユークリッド距離ではなく、SVMのカーネルトリックを用いた空間を考慮してのスコア算出になるので、たしかによりよい結果になりそうだ。LinearSVCを使ってそれを計算するのはなるほど!という感じであった。この手法をLangChainでいい感じに使えるように抽象化したものが、SVM Retrieverの実装というわけだ。
実際に kNN と SVM の結果と比較する
では AI Newsのデータのうち、日本語のデータ450件を使って、特定クエリでkNNとSVMで類似コンテンツを検索した結果を見てみよう。
query: 生成AIと著作権
=== kNN ===
0.886: 生成AIの猛烈な進化と著作権制度~技術発展と著作権者の利益のバランスをとるには~ | STORIA法律事務所
0.880: スター・ウォーズやハリポタの人気キャラと話せるAIの「著作権問題」をどう考えるべきか | シリコンバレーの「生き字引」がズバリ指摘 | クーリエ・ジャポン
0.876: 生成AIの利用ガイドライン作成のための手引き | STORIA法律事務所
0.876: ダブスタクソイナゴは生成AIの法的議論に参加してくるんじゃねえ!!
0.874: 画像生成AI “クリエーターの権利脅かされる” 法整備など提言 | NHK | AI(人工知能)
0.870: 【AI】生成AIを利用する場合に気を付けなければならない著作権の知識|福岡真之介|note
0.868: AIイラストに規制を求める団体の理事「木目百二」氏が二次創作のガイドライン違反で支援サイトの作品全消し、謝罪に追い込まれる - Togetter
0.865: 生成AI「開発規制、望ましくない」 松本総務相 - 日本経済新聞
=== SVM ===
-0.305: 生成AIで作品、それって著作権侵害? 福井健策弁護士に聞く:朝日新聞デジタル
-0.384: 生成AIの猛烈な進化と著作権制度~技術発展と著作権者の利益のバランスをとるには~ | STORIA法律事務所
-0.402: ダブスタクソイナゴは生成AIの法的議論に参加してくるんじゃねえ!!
-0.408: AIイラストに規制を求める団体の理事「木目百二」氏が二次創作のガイドライン違反で支援サイトの作品全消し、謝罪に追い込まれる - Togetter
-0.436: 画像生成AIによる作品の無許可使用を主張した写真家が逆に損害賠償を請求される - GIGAZINE
-0.479: アーティストのGrimes、生成AIで自分の声を自由に使っていいとツイート - ITmedia NEWS
-0.482: 生成AIの利用ガイドライン作成のための手引き | STORIA法律事務所
-0.483: スター・ウォーズやハリポタの人気キャラと話せるAIの「著作権問題」をどう考えるべきか | シリコンバレーの「生き字引」がズバリ指摘 | クーリエ・ジャポン
上記結果だと正直どっちでもパッと見は適切のように見える。もうちょっと難しそうなクエリで見てみる。
query: 大規模言語モデルを低スペックのマシンで動かしたい
=== kNN ===
0.872: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
0.861: チャットAIをブラウザのWebGPUだけで実行でき日本語も使用できる「Web LLM」、実際に試してみる方法はこんな感じ - GIGAZINE
0.855: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
0.853: ChatGPT対抗のオープンソース言語モデル「StableLM」。日本語版も? - PC Watch
0.851: “画像の面白さ”を解説できるAI「MiniGPT-4」 写真からラップや詩、料理レシピ作成 デモサイトも公開中:Innovative Tech(1/2 ページ) - ITmedia NEWS
0.850: チャットAI「StableLM」発表 オープンソースモデルで商用可 「Stable Diffusion」開発元から - ITmedia NEWS
0.849: Googleの大規模言語モデル「Bard」、日本でも利用可能に。英語のみだが、改良されたPaLMベース | テクノエッジ TechnoEdge
0.849: Stability AIがオープンソースで商用利用も可能な大規模言語モデル「StableLM」をリリース - GIGAZINE
=== SVM ===
-0.359: 大規模言語モデルを自社でトレーニング&活用する方法|mah_lab / 西見 公宏|note
-0.366: 大規模言語モデル間の性能比較まとめ|mah_lab / 西見 公宏|note
-0.451: 深層学習コンパイラスタックと最適化
-0.456: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-0.471: dolly-v2-12bという120億パラメータの言語モデルを使ってみた!|Masayuki Abe|note
-0.490: Googleの大規模言語モデル「Bard」、日本でも利用可能に。英語のみだが、改良されたPaLMベース | テクノエッジ TechnoEdge
-0.504: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
-0.510: Webブラウザ上で3D/2Dモデルをぬるぬる動かせる「Babylon.js 6」正式版に。レンダリング性能が最大50倍、WASM化した物理演算エンジン搭載、液体のレンダリングも - Publickey
クエリによっては結構結果がバラける。kNNとSVMの結果のアンサンブルによるハイブリッド検索も実装してみたので、それで見てみる。
query: 大規模言語モデルを低スペックのマシンで動かしたい
=== kNN ===
-3.816: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
-3.527: チャットAIをブラウザのWebGPUだけで実行でき日本語も使用できる「Web LLM」、実際に試してみる方法はこんな感じ - GIGAZINE
-2.920: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-2.591: ChatGPT対抗のオープンソース言語モデル「StableLM」。日本語版も? - PC Watch
-2.436: “画像の面白さ”を解説できるAI「MiniGPT-4」 写真からラップや詩、料理レシピ作成 デモサイトも公開中:Innovative Tech(1/2 ページ) - ITmedia NEWS
=== SVM ===
-3.923: 大規模言語モデルを自社でトレーニング&活用する方法|mah_lab / 西見 公宏|note
-3.865: 大規模言語モデル間の性能比較まとめ|mah_lab / 西見 公宏|note
-3.140: 深層学習コンパイラスタックと最適化
-3.097: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-2.962: dolly-v2-12bという120億パラメータの言語モデルを使ってみた!|Masayuki Abe|note
=== Hybrid ===
-3.869: 大規模言語モデルを自社でトレーニング&活用する方法|mah_lab / 西見 公宏|note
-3.102: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
-2.913: 大規模言語モデル間の性能比較まとめ|mah_lab / 西見 公宏|note
-2.844: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-2.558: チャットAIをブラウザのWebGPUだけで実行でき日本語も使用できる「Web LLM」、実際に試してみる方法はこんな感じ - GIGAZINE
おー、より良い感じの結果になったと思う。簡単に使えるので、kNN での検索や類似度探索に加え、SVMも使ってみるというのも悪くなさそうな感じだなぁ。もちろん速度はkNNのほうが圧倒的に速いと思うが、SVMも現実的な速度で使える場合は良い気がしている。
おまけコード
embsはembeddingsの配列で別途作る必要がある。textsはembsとペアのデータ。なおLangChainのSVM Retrieverを使うともっと簡単にかけるが、スコアなどが手に入らないので自前で実装してる。
# base: https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
from sklearn import svm
import numpy as np
from langchain.embeddings import OpenAIEmbeddings
def knn_top_k(query_emb, embs, k=10):
l2_embs = embs / np.sqrt((embs**2).sum(1, keepdims=True))
l2_query = query_emb / np.sqrt((query_emb**2).sum())
similarities = l2_embs.dot(l2_query)
sorted_index = np.argsort(-similarities)
res_index = sorted_index[1:k+1]
return res_index, similarities[res_index], -similarities
def svm_top_k(query_emb, embs, k=10):
X = np.concatenate([query_emb[None, ...], embs])
y = np.zeros(X.shape[0])
y[0] = 1
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
clf.fit(X, y)
similarities = clf.decision_function(X)
sorted_index = np.argsort(-similarities)
res_index = sorted_index[1:k+1] - 1
return res_index, similarities[res_index + 1], -similarities[1:]
def get_query_emb(text):
emb = OpenAIEmbeddings().embed_query(text) # type: ignore
return np.array(emb)
def join_colon(num_list_a, list_b):
return [f'{a:.3f}: {b}' for a, b in zip(num_list_a, list_b)]
def knn_svm(text, embs, texts, k=5):
query_emb = get_query_emb(text)
knn_index, knn_similarities, _ = knn_top_k(query_emb, embs, k)
svm_index, svm_similarities, _ = svm_top_k(query_emb, embs, k)
print('query: ', text)
print('=== kNN ===')
print("\n".join(join_colon(knn_similarities, texts[knn_index])))
print('=== SVM ===')
print("\n".join(join_colon(svm_similarities, texts[svm_index])))
def hyblid_knn_svm(text_or_emb, embs, texts, k=5):
if isinstance(text_or_emb, str):
query_emb = get_query_emb(text_or_emb)
print('query: ', text_or_emb) # type: ignore
else:
query_emb = text_or_emb
# 全件取得する
knn_index, knn_similarities, knn_all_scores = knn_top_k(query_emb, embs, embs.shape[0])
svm_index, svm_similarities, svm_all_scores = svm_top_k(query_emb, embs, embs.shape[0])
# score を正規化する
knn_score_normalized = (knn_all_scores - np.mean(knn_all_scores)) / np.std(knn_all_scores)
svm_score_normalized = (svm_all_scores - np.mean(svm_all_scores)) / np.std(svm_all_scores)
# それぞれのスコアを足し合わせて、ハイブリッドなスコアを作る
hybrid_similarities = (knn_score_normalized + svm_score_normalized) / 2
hybrid_index = np.argsort(hybrid_similarities)[:k]
print('=== kNN ===')
print("\n".join(join_colon(np.sort(knn_score_normalized)[:k], texts[knn_index][:k])))
print('=== SVM ===')
print("\n".join(join_colon(np.sort(svm_score_normalized)[:k], texts[svm_index][:k])))
print('=== Hybrid ===')
print("\n".join(join_colon(hybrid_similarities[hybrid_index][:k], texts[hybrid_index][:k])))