概要
LangChainのEnsemble Retrieverの使い方をまとめる。
今回はBM25、HuggingFace(sonoisa)、OpenAI(text-embedding-ada-002)の3つでEnsemble Retrieverを使ってみます。
Ensemble Retriever
検索精度を向上させるために、複数の検索結果を使用して順位を計算します。(ハイブリット検索)
バージョン
langchain==0.1.6
langchain-openai=0.0.5
実装
ライブラリのインポート
import os from langchain.retrievers import BM25Retriever, EnsembleRetriever from langchain_community.vectorstores import FAISS from langchain_openai import OpenAIEmbeddings from langchain.embeddings import HuggingFaceBgeEmbeddings from sklearn.metrics.pairwise import cosine_similarity from langchain.vectorstores.utils import DistanceStrategy from typing import List from langchain.retrievers.bm25 import default_preprocessing_func
テキストのサンプル
sentences = ["明日の天気は晴れです", "明日の天気は雨です", "晴れが好きです","私はリンゴが好きです"] query1 = "次の日は良い天気です" query2 = "私はミカンが好きです"
OpenAIのAPIKEYを設定
API_KEYは事前に環境変数に登録しておきます。
環境変数の登録方法は以下の記事に書かれています。
api_key = os.environ["OPENAI_API_KEY"]
BM25
BM25はそのまま日本語では使えないようなので、
preprocess_funcに日本語の分かち書きに適した関数を与えます。
以下の記事を参考にしています。
def generate_character_ngrams(text, i, j, binary=False): """ 文字列から指定した文字数のn-gramを生成する関数。 :param text: 文字列データ :param i: n-gramの最小文字数 :param j: n-gramの最大文字数 :param binary: Trueの場合、重複を削除 :return: n-gramのリスト """ ngrams = [] for n in range(i, j + 1): for k in range(len(text) - n + 1): ngram = text[k:k + n] ngrams.append(ngram) if binary: ngrams = list(set(ngrams)) # 重複を削除 return ngrams def preprocess_func(text: str) -> List[str]: i, j = 1, 3 if len(text) < i: return [text] return generate_character_ngrams(text, i, j, True)
bm25_retriever = BM25Retriever.from_texts( sentences, preprocess_func=preprocess_func, metadatas=[{"source":1}]*len(sentences) ) bm25_retriever.k = 4
scores = bm25_retriever.vectorizer.get_scores(preprocess_func(query1)) print("query=", query1) for text, score in zip(sentences, scores): print(text, score)
query= 次の日は良い天気です 明日の天気は晴れです 0.274919929536986 明日の天気は雨です 0.2903841755734415 晴れが好きです 0.2453950779493872 私はリンゴが好きです 0.274919929536986
「次の日は良い天気です」に対して、「明日の天気は雨です」が上位にきました。
正解を選択できませんでした。
scores = bm25_retriever.vectorizer.get_scores(preprocess_func(query2)) print("query=", query2) for text, score in zip(sentences, scores): print(text, score)
query= 私はミカンが好きです 明日の天気は晴れです 0.274919929536986 明日の天気は雨です 0.2903841755734415 晴れが好きです 0.2453950779493872 私はリンゴが好きです 2.681446396908925
「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらは正解を選択できています。
HuggingFace
embeddingのモデルはsonoisaを使用しました。 huggingface.co
model_name = "sonoisa/sentence-bert-base-ja-mean-tokens-v2" model_kwargs = {'device': 'cuda'} #cpuの場合は'cpu' encode_kwargs = {'normalize_embeddings': True} embeddings_hf = HuggingFaceBgeEmbeddings( model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs ) faiss_vectorstore_hf = FAISS.from_texts(sentences, embeddings_hf, distance_strategy = DistanceStrategy.MAX_INNER_PRODUCT) faiss_retriever_hf = faiss_vectorstore_hf.as_retriever(search_kwargs={"k":4})
result_hf = faiss_retriever_hf.get_relevant_documents(query1) print("query=", query1) for doc in result_hf: print(doc)
query= 次の日は良い天気です page_content='明日の天気は晴れです' page_content='晴れが好きです' page_content='明日の天気は雨です' page_content='私はリンゴが好きです'
「次の日は良い天気です」に対して、「明日の天気は晴れです」が上位にきました。
正解を選択できています。
result_hf = faiss_retriever_hf.get_relevant_documents(query2) print("query=", query2) for doc in result_hf: print(doc)
query= 私はミカンが好きです page_content='私はリンゴが好きです' page_content='晴れが好きです' page_content='明日の天気は雨です' page_content='明日の天気は晴れです'
「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらも正解を選択できています。
OpenAI
embeddings_ada = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=api_key) faiss_vectorstore_ada = FAISS.from_texts( sentences, embeddings_ada, metadatas=[{"source":2}]*len(sentences) ) faiss_retriever_ada = faiss_vectorstore.as_retriever(search_kwargs={"k":4}, distance_strategy = DistanceStrategy.MAX_INNER_PRODUCT)
result_ada = faiss_retriever_ada.get_relevant_documents(query1) print("query=", query1) for doc in result_ada: print(doc)
query= 次の日は良い天気です page_content='晴れが好きです' metadata={'source': 2} page_content='明日の天気は晴れです' metadata={'source': 2} page_content='明日の天気は雨です' metadata={'source': 2} page_content='私はリンゴが好きです' metadata={'source': 2}
「次の日は良い天気です」に対して、「晴れが好きです」が上位にきました。
正解を選択できませんでした。
result_ada = faiss_retriever_ada.get_relevant_documents(query2) print("query=", query2) for doc in result_ada: print(doc)
query= 私はミカンが好きです page_content='私はリンゴが好きです' metadata={'source': 2} page_content='晴れが好きです' metadata={'source': 2} page_content='明日の天気は雨です' metadata={'source': 2} page_content='明日の天気は晴れです' metadata={'source': 2}
「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらは正解を選択できています。
Ensemble Retriever
上記3つの検索結果を使用してハイブリット検索してみます。
Ensemble Retrieverのコードを見てみると、以下の計算式で最終スコアを計算しているようです。
c: int = 60 rrf_score = weight * (1 / (rank + self.c))
ensemble_retriever = EnsembleRetriever( retrievers=[bm25_retriever, faiss_retriever,faiss_retriever_hf], weights=[0.33, 0.33, 0.33] )
docs = ensemble_retriever.invoke(query1) print("query=", query1) docs
query= 次の日は良い天気です [Document(page_content='明日の天気は晴れです'), Document(page_content='晴れが好きです'), Document(page_content='明日の天気は雨です'), Document(page_content='私はリンゴが好きです')]
「次の日は良い天気です」に対して、「明日の天気は晴れです」が上位にきました。
正解を選択できています。
docs = ensemble_retriever.invoke(query2) print("query=", query2) docs
query= 私はミカンが好きです [Document(page_content='私はリンゴが好きです'), Document(page_content='晴れが好きです'), Document(page_content='明日の天気は雨です'), Document(page_content='明日の天気は晴れです')]
「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらも正解を選択できています。
結果
sonoisaとensembleで欲しかった結果が得られました。
今回のような単純な文章であればはsonoisa単体でも良い結果が得られましたが、
複雑な文章であれば、ensemble retrieverを使ったほうが検索精度の向上が見込まれると思います。
モデル | 次の日は良い天気です | 私はミカンが好きです |
---|---|---|
BM25 | 明日の天気は雨です | 私はリンゴが好きです |
sonoisa | 明日の天気は晴れです | 私はリンゴが好きです |
ada | 晴れが好きです | 私はリンゴが好きです |
ensemble | 明日の天気は晴れです | 私はリンゴが好きです |