データサイエンス学習記録

ひよっこAIエンジニアの学習記録です。

【LangChain】Ensemble Retrieverを使ってみた。

概要

LangChainのEnsemble Retrieverの使い方をまとめる。

今回はBM25、HuggingFace(sonoisa)、OpenAI(text-embedding-ada-002)の3つでEnsemble Retrieverを使ってみます。

Ensemble Retriever

検索精度を向上させるために、複数の検索結果を使用して順位を計算します。(ハイブリット検索)

python.langchain.com

バージョン

langchain==0.1.6
langchain-openai=0.0.5

実装

ライブラリのインポート

import os
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from sklearn.metrics.pairwise import cosine_similarity
from langchain.vectorstores.utils import DistanceStrategy
from typing import List
from langchain.retrievers.bm25 import default_preprocessing_func

テキストのサンプル

sentences = ["明日の天気は晴れです", "明日の天気は雨です", "晴れが好きです","私はリンゴが好きです"]
query1 = "次の日は良い天気です"
query2 = "私はミカンが好きです"

OpenAIのAPIKEYを設定

API_KEYは事前に環境変数に登録しておきます。
環境変数の登録方法は以下の記事に書かれています。

happy-shibusawake.com

api_key = os.environ["OPENAI_API_KEY"]

BM25

BM25はそのまま日本語では使えないようなので、
preprocess_funcに日本語の分かち書きに適した関数を与えます。
以下の記事を参考にしています。

qiita.com

def generate_character_ngrams(text, i, j, binary=False):
    """
    文字列から指定した文字数のn-gramを生成する関数。

    :param text: 文字列データ
    :param i: n-gramの最小文字数
    :param j: n-gramの最大文字数
    :param binary: Trueの場合、重複を削除
    :return: n-gramのリスト
    """
    ngrams = []
    
    for n in range(i, j + 1):
        for k in range(len(text) - n + 1):
            ngram = text[k:k + n]
            ngrams.append(ngram)
    
    if binary:
        ngrams = list(set(ngrams))  # 重複を削除
    
    return ngrams

def preprocess_func(text: str) -> List[str]:
    i, j = 1, 3
    if len(text) < i:
        return [text]
    return generate_character_ngrams(text, i, j, True)
bm25_retriever = BM25Retriever.from_texts(
    sentences, preprocess_func=preprocess_func, metadatas=[{"source":1}]*len(sentences)
)
bm25_retriever.k = 4
scores = bm25_retriever.vectorizer.get_scores(preprocess_func(query1))
print("query=", query1)
for text, score in zip(sentences, scores):
    print(text, score)
query= 次の日は良い天気です
明日の天気は晴れです 0.274919929536986
明日の天気は雨です 0.2903841755734415
晴れが好きです 0.2453950779493872
私はリンゴが好きです 0.274919929536986

「次の日は良い天気です」に対して、「明日の天気は雨です」が上位にきました。
正解を選択できませんでした。

scores = bm25_retriever.vectorizer.get_scores(preprocess_func(query2))
print("query=", query2)
for text, score in zip(sentences, scores):
    print(text, score)
query= 私はミカンが好きです
明日の天気は晴れです 0.274919929536986
明日の天気は雨です 0.2903841755734415
晴れが好きです 0.2453950779493872
私はリンゴが好きです 2.681446396908925

「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらは正解を選択できています。

HuggingFace

embeddingのモデルはsonoisaを使用しました。 huggingface.co

model_name = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
model_kwargs = {'device': 'cuda'} #cpuの場合は'cpu'
encode_kwargs = {'normalize_embeddings': True}
embeddings_hf = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

faiss_vectorstore_hf = FAISS.from_texts(sentences, embeddings_hf, distance_strategy = DistanceStrategy.MAX_INNER_PRODUCT)
faiss_retriever_hf = faiss_vectorstore_hf.as_retriever(search_kwargs={"k":4})
result_hf = faiss_retriever_hf.get_relevant_documents(query1)
print("query=", query1)
for doc in result_hf:
    print(doc)
query= 次の日は良い天気です
page_content='明日の天気は晴れです'
page_content='晴れが好きです'
page_content='明日の天気は雨です'
page_content='私はリンゴが好きです'

「次の日は良い天気です」に対して、「明日の天気は晴れです」が上位にきました。
正解を選択できています。

result_hf = faiss_retriever_hf.get_relevant_documents(query2)
print("query=", query2)
for doc in result_hf:
    print(doc)
query= 私はミカンが好きです
page_content='私はリンゴが好きです'
page_content='晴れが好きです'
page_content='明日の天気は雨です'
page_content='明日の天気は晴れです'

「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらも正解を選択できています。

OpenAI

embeddings_ada = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=api_key)
faiss_vectorstore_ada = FAISS.from_texts(
    sentences, embeddings_ada, metadatas=[{"source":2}]*len(sentences)
)

faiss_retriever_ada = faiss_vectorstore.as_retriever(search_kwargs={"k":4}, distance_strategy = DistanceStrategy.MAX_INNER_PRODUCT)
result_ada = faiss_retriever_ada.get_relevant_documents(query1)
print("query=", query1)
for doc in result_ada:
    print(doc)
query= 次の日は良い天気です
page_content='晴れが好きです' metadata={'source': 2}
page_content='明日の天気は晴れです' metadata={'source': 2}
page_content='明日の天気は雨です' metadata={'source': 2}
page_content='私はリンゴが好きです' metadata={'source': 2}

「次の日は良い天気です」に対して、「晴れが好きです」が上位にきました。
正解を選択できませんでした。

result_ada = faiss_retriever_ada.get_relevant_documents(query2)
print("query=", query2)
for doc in result_ada:
    print(doc)
query= 私はミカンが好きです
page_content='私はリンゴが好きです' metadata={'source': 2}
page_content='晴れが好きです' metadata={'source': 2}
page_content='明日の天気は雨です' metadata={'source': 2}
page_content='明日の天気は晴れです' metadata={'source': 2}

「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらは正解を選択できています。

Ensemble Retriever

上記3つの検索結果を使用してハイブリット検索してみます。
Ensemble Retrieverのコードを見てみると、以下の計算式で最終スコアを計算しているようです。

c: int = 60
rrf_score = weight * (1 / (rank + self.c))
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever,faiss_retriever_hf], weights=[0.33, 0.33, 0.33]
)
docs = ensemble_retriever.invoke(query1)
print("query=", query1)
docs
query= 次の日は良い天気です
[Document(page_content='明日の天気は晴れです'),
 Document(page_content='晴れが好きです'),
 Document(page_content='明日の天気は雨です'),
 Document(page_content='私はリンゴが好きです')]

「次の日は良い天気です」に対して、「明日の天気は晴れです」が上位にきました。
正解を選択できています。

docs = ensemble_retriever.invoke(query2)
print("query=", query2)
docs
query= 私はミカンが好きです
[Document(page_content='私はリンゴが好きです'),
 Document(page_content='晴れが好きです'),
 Document(page_content='明日の天気は雨です'),
 Document(page_content='明日の天気は晴れです')]

「私はミカンが好きです」に対して、「私はリンゴが好きです」が上位にきました。
こちらも正解を選択できています。

結果

sonoisaとensembleで欲しかった結果が得られました。
今回のような単純な文章であればはsonoisa単体でも良い結果が得られましたが、
複雑な文章であれば、ensemble retrieverを使ったほうが検索精度の向上が見込まれると思います。

モデル 次の日は良い天気です 私はミカンが好きです
BM25 明日の天気は雨です 私はリンゴが好きです
sonoisa 明日の天気は晴れです 私はリンゴが好きです
ada 晴れが好きです 私はリンゴが好きです
ensemble 明日の天気は晴れです 私はリンゴが好きです