import asyncio
import copy
import time
import warnings
_TERMINAL = object() # sentinel for explicit terminal transitions in Flow.connect
[docs]
class BaseNode:
[docs]
def __init__(self):
self.params, self.successors = {}, {}
[docs]
def set_params(self, params):
self.params = params
[docs]
def prep(self, shared):
pass
[docs]
def exec(self, prep_res):
pass
[docs]
def post(self, shared, prep_res, exec_res):
pass
def _exec(self, prep_res):
return self.exec(prep_res)
def _run(self, shared):
p = self.prep(shared)
e = self._exec(p)
return self.post(shared, p, e)
[docs]
def run(self, shared):
if self.successors:
warnings.warn("Node won't run successors. Use Flow.")
return self._run(shared)
[docs]
class Node(BaseNode):
[docs]
def __init__(self, max_retries=1, wait=0):
super().__init__()
self.max_retries, self.wait = max_retries, wait
[docs]
def exec_fallback(self, prep_res, exc):
raise exc
def _exec(self, prep_res):
for self.cur_retry in range(self.max_retries):
try:
return self.exec(prep_res)
except Exception as e:
if self.cur_retry == self.max_retries - 1:
return self.exec_fallback(prep_res, e)
if self.wait > 0:
time.sleep(self.wait)
[docs]
class BatchNode(Node):
def _exec(self, items):
return [super(BatchNode, self)._exec(i) for i in (items or [])]
[docs]
class Flow(BaseNode):
[docs]
def __init__(self, start=None):
super().__init__()
self.start_node = start
[docs]
def start(self, start):
self.start_node = start
return start
[docs]
def connect(self, src, dst, action="default"):
if action in src.successors:
warnings.warn(f"Overwriting successor for action '{action}'")
src.successors[action] = _TERMINAL if dst is None else dst
return self
[docs]
def get_next_node(self, curr, action):
nxt = curr.successors.get(action or "default")
if nxt is _TERMINAL:
return None
if nxt is None and curr.successors:
warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
return nxt
def _orch(self, shared, params=None):
curr, p, last_action = (
copy.copy(self.start_node),
(params or {**self.params}),
None,
)
while curr:
curr.set_params(p)
last_action = curr._run(shared)
curr = copy.copy(self.get_next_node(curr, last_action))
return last_action
def _run(self, shared):
p = self.prep(shared)
o = self._orch(shared)
return self.post(shared, p, o)
[docs]
def post(self, shared, prep_res, exec_res):
return exec_res
[docs]
class BatchFlow(Flow):
def _run(self, shared):
pr = self.prep(shared) or []
for bp in pr:
self._orch(shared, {**self.params, **bp})
return self.post(shared, pr, None)
[docs]
class AsyncNode(Node):
[docs]
async def prep_async(self, shared):
pass
[docs]
async def exec_async(self, prep_res):
pass
[docs]
async def exec_fallback_async(self, prep_res, exc):
raise exc
[docs]
async def post_async(self, shared, prep_res, exec_res):
pass
async def _exec(self, prep_res):
for self.cur_retry in range(self.max_retries):
try:
return await self.exec_async(prep_res)
except Exception as e:
if self.cur_retry == self.max_retries - 1:
return await self.exec_fallback_async(prep_res, e)
if self.wait > 0:
await asyncio.sleep(self.wait)
[docs]
async def run_async(self, shared):
if self.successors:
warnings.warn("Node won't run successors. Use AsyncFlow.")
return await self._run_async(shared)
async def _run_async(self, shared):
p = await self.prep_async(shared)
e = await self._exec(p)
return await self.post_async(shared, p, e)
def _run(self, shared):
raise RuntimeError("Use run_async.")
[docs]
class AsyncBatchNode(AsyncNode, BatchNode):
async def _exec(self, items):
return [await super(AsyncBatchNode, self)._exec(i) for i in items]
[docs]
class AsyncParallelBatchNode(AsyncNode, BatchNode):
async def _exec(self, items):
return await asyncio.gather(
*(super(AsyncParallelBatchNode, self)._exec(i) for i in items)
)
[docs]
class AsyncFlow(Flow, AsyncNode):
async def _orch_async(self, shared, params=None):
curr, p, last_action = (
copy.copy(self.start_node),
(params or {**self.params}),
None,
)
while curr:
curr.set_params(p)
last_action = (
await curr._run_async(shared)
if isinstance(curr, AsyncNode)
else curr._run(shared)
)
curr = copy.copy(self.get_next_node(curr, last_action))
return last_action
async def _run_async(self, shared):
p = await self.prep_async(shared)
o = await self._orch_async(shared)
return await self.post_async(shared, p, o)
[docs]
async def post_async(self, shared, prep_res, exec_res):
return exec_res
[docs]
class AsyncBatchFlow(AsyncFlow, BatchFlow):
async def _run_async(self, shared):
pr = await self.prep_async(shared) or []
for bp in pr:
await self._orch_async(shared, {**self.params, **bp})
return await self.post_async(shared, pr, None)
[docs]
class AsyncParallelBatchFlow(AsyncFlow, BatchFlow):
async def _run_async(self, shared):
pr = await self.prep_async(shared) or []
await asyncio.gather(
*(self._orch_async(shared, {**self.params, **bp}) for bp in pr)
)
return await self.post_async(shared, pr, None)
from .llm import LLMClient # noqa: E402 - avoid circular import with .llm
from .rag import RAGNode # noqa: E402 - avoid circular import with .rag
__all__ = [
"BaseNode",
"Node",
"BatchNode",
"Flow",
"BatchFlow",
"AsyncNode",
"AsyncBatchNode",
"AsyncParallelBatchNode",
"AsyncFlow",
"AsyncBatchFlow",
"AsyncParallelBatchFlow",
"LLMClient",
"RAGNode",
]