Source code for fenn.agents.rag.rag_node
from fenn.agents import Node
from .loader import load_documents
from .retriever import Retriever
[docs]
class RAGNode(Node):
"""
Flow node that retrieves relevant context from indexed sources.
Loads and indexes all sources once at construction time, then per run
queries the index using ``shared[query_key]`` and writes the results
into ``shared[chunks_key]`` and ``shared[context_key]``.
Parameters
----------
sources : str or list of str, optional
File paths, folder paths, or URLs to load and index on init.
Additional sources can be indexed later with :meth:`add_source`.
query_key : str
Key in ``shared`` that holds the user query. Default: ``"query"``.
context_key : str
Key written into ``shared`` with the concatenated chunk text.
Default: ``"rag_context"``.
chunks_key : str
Key written into ``shared`` with the raw list of chunks.
Default: ``"rag_chunks"``.
top_k : int
Maximum number of chunks to retrieve. Default: 5.
next_action : str
Action string returned by ``post()``, used by ``Flow.get_next_node()``.
Default: ``"default"``.
faiss : bool
Use FAISS semantic search instead of BM25. Default: False.
embedding_provider : str
Embedding provider (only used when faiss=True). Default: ``"local"``.
embedding_model : str
Embedding model (only used when faiss=True). Default: ``"all-MiniLM-L6-v2"``.
embedding_api_key : str, optional
API key for the embedding provider.
chunk_mode : str
Document chunking strategy. One of ``"smart"``, ``"paragraphs"``,
``"sentences"``, ``"fixed"``. Default: ``"smart"``.
persist_path : str or Path, optional
Directory to save/load the FAISS index. Only used when faiss=True.
"""
[docs]
def __init__(
self,
sources=None,
query_key="query",
context_key="rag_context",
chunks_key="rag_chunks",
top_k=5,
next_action="default",
faiss=False,
embedding_provider="local",
embedding_model="all-MiniLM-L6-v2",
embedding_api_key=None,
chunk_mode="smart",
persist_path=None,
):
super().__init__()
self._query_key = query_key
self._context_key = context_key
self._chunks_key = chunks_key
self._top_k = top_k
self._next_action = next_action
self._retriever = Retriever(
use_faiss=faiss,
embedding_provider=embedding_provider,
embedding_model=embedding_model,
embedding_api_key=embedding_api_key,
chunk_mode=chunk_mode,
persist_path=persist_path,
)
if sources:
if isinstance(sources, str):
sources = [sources]
for s in sources:
self._retriever.index(load_documents(s))
[docs]
def add_source(self, source):
"""Index an additional source. Returns self for chaining."""
self._retriever.index(load_documents(source))
return self
[docs]
def prep(self, shared):
return shared.get(self._query_key, "")
[docs]
def exec(self, query):
return self._retriever.query(query, top_k=self._top_k)
[docs]
def post(self, shared, query, chunks):
shared[self._chunks_key] = chunks
shared[self._context_key] = "\n\n".join(chunks)
return self._next_action