LLM planner, not perfect but ok

This commit is contained in:
Steve Androulakis
2025-01-01 16:57:08 -08:00
parent 33af355363
commit 245d64fca9
8 changed files with 275 additions and 123 deletions

View File

@@ -102,16 +102,16 @@ def get_current_date_human_readable():
@activity.defn(dynamic=True)
def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
    """Invoked for an unknown activity type, delegates to the correct tool.

    The diff residue here duplicated several old/new statements (two
    tool_name assignments, two logger calls, a dead handler_func binding);
    this keeps only the post-commit versions of each line.

    Args:
        args: Raw Temporal payloads; args[0] is decoded into the tool's
            keyword-argument dict.

    Returns:
        Whatever dict the resolved tool handler produces.
    """
    # Local import so the registry is only loaded inside the activity sandbox.
    from tools import get_handler

    # The dynamic activity's type name doubles as the tool name.
    tool_name = activity.info().activity_type  # e.g. "FindEvents"
    tool_args = activity.payload_converter().from_payload(args[0].payload, dict)
    activity.logger.info(f"Dynamic activity triggered for tool: {tool_name}")

    # Delegate to the relevant function
    handler = get_handler(tool_name)
    result = handler(tool_args)

    # Optionally log or augment the result
    activity.logger.info(f"Tool '{tool_name}' result: {result}")
    return result

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List
from dataclasses import dataclass, field
from typing import List, Dict, Any
@dataclass
@@ -19,3 +19,14 @@ class ToolDefinition:
@dataclass
class ToolsData:
    # All tool definitions made available to the LLM planner/agent.
    tools: List[ToolDefinition]
@dataclass
class ToolInvocation:
    # A single planned tool call: which tool to run and with what arguments.
    tool: str  # tool name, e.g. "FindEvents" — must match a registered handler
    args: Dict[str, Any]  # argument dict forwarded to the tool handler
@dataclass
class MultiToolSequence:
    # Ordered list of tool invocations to execute one after another.
    # default_factory avoids the shared-mutable-default pitfall.
    tool_invocations: List[ToolInvocation] = field(default_factory=list)

View File

@@ -6,54 +6,68 @@ def generate_genai_prompt_from_tools_data(
) -> str:
"""
Generates a prompt describing the tools and the instructions for the AI
assistant, using the conversation history provided.
assistant, using the conversation history provided, allowing for multiple
tools and a 'done' state.
"""
prompt_lines = []
prompt_lines.append(
"You are an AI assistant that must determine all required arguments"
"You are an AI assistant that must determine all required arguments "
"for the tools to achieve the user's goal. "
)
prompt_lines.append("")
prompt_lines.append(
"Conversation history so far. \nANALYZE THIS HISTORY TO DETERMINE WHICH ARGUMENTS TO PRE-FILL AS SPECIFIED FOR THE TOOL BELOW: "
)
prompt_lines.append("for the tools to achieve the user's goal.\n")
prompt_lines.append("Conversation history so far:")
prompt_lines.append(conversation_history)
prompt_lines.append("")
# List all tools and their arguments
prompt_lines.append("Available tools and their required arguments:")
for tool in tools_data.tools:
prompt_lines.append(f"Tool to run: {tool.name}")
prompt_lines.append(f"Description: {tool.description}")
prompt_lines.append("Arguments needed:")
prompt_lines.append(f"- Tool name: {tool.name}")
prompt_lines.append(f" Description: {tool.description}")
prompt_lines.append(" Arguments needed:")
for arg in tool.arguments:
prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}")
prompt_lines.append("")
prompt_lines.append("Instructions:")
prompt_lines.append(
"1. You need to ask the user (or confirm with them) for each argument required by the tools above."
"1. You may call multiple tools in sequence if needed, each requiring certain arguments. "
"Ask the user for missing details when necessary. "
)
prompt_lines.append(
"2. If you do not yet have a specific argument value, ask the user for it."
"2. If you do not yet have a specific argument value, ask the user for it by setting 'next': 'question'."
)
prompt_lines.append(
"3. Once you have all arguments, read them back to confirm with the user before yielding to the tool to take action.\n"
"3. Once you have enough information for a particular tool, respond with 'next': 'confirm' and include the tool name in 'tool'."
)
prompt_lines.append(
'Your response must be valid JSON in the format: {"response": "<ai response>", "next": "<question|confirm>", '
+ '"tool": "<tool_name>", "arg1": "value1", "arg2": "value2"}" where args are the arguments for the tool (or null if unknown so far)."'
"4. If you have completed all necessary tools (no more actions needed), use 'next': 'done' in your JSON response ."
)
prompt_lines.append(
'- Your goal is to convert the AI responses into filled args in the JSON and once all args are filled, confirm with the user.".'
"5. Your response must be valid JSON in this format:\n"
" {\n"
' "response": "<plain text to user>",\n'
' "next": "<question|confirm|done>",\n'
' "tool": "<tool_name or none>",\n'
' "args": {\n'
' "<arg1>": "<value1>",\n'
' "<arg2>": "<value2>", ...\n'
" }\n"
" }\n"
" where 'args' are the arguments for the tool (or empty if not needed)."
)
prompt_lines.append(
'- If you still need information from the user, use "next": "question".'
"6. If you still need information from the user, use 'next': 'question'. "
"If you have enough info for a specific tool, use 'next': 'confirm'. "
"Do NOT use 'next': 'confirm' until you have all necessary arguments (i.e. they're NOT 'null') ."
"If you are finished with all tools, use 'next': 'done'."
)
prompt_lines.append(
'- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.'
"7. Keep responses in plain text. Return valid JSON without extra commentary."
)
prompt_lines.append(
'- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from <city> to <city> from <date> to <date>. Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateDepart": "2025-01-04", "dateReturn": "2025-01-08"}, "next": "confirm", "tool": "<toolName>" }'
)
prompt_lines.append("- Return valid JSON without special characters.")
prompt_lines.append("")
prompt_lines.append("Begin by prompting or confirming the necessary details.")
@@ -66,9 +80,10 @@ def generate_json_validation_prompt_from_tools_data(
"""
Generates a prompt instructing the AI to:
1. Check that the given raw JSON is syntactically valid.
2. Ensure the 'tool' matches one of the defined tools in tools_data.
3. Confirm or correct that all required arguments are present and make sense.
2. Ensure the 'tool' matches one of the defined tools or is 'none' if no tool is needed.
3. Confirm or correct that all required arguments are present or set to null if missing.
4. Return a corrected JSON if possible.
5. Accept 'next' as one of 'question', 'confirm', or 'done'.
"""
prompt_lines = []
@@ -78,8 +93,9 @@ def generate_json_validation_prompt_from_tools_data(
prompt_lines.append("It may be malformed or incomplete.")
prompt_lines.append("You also have a list of tools and their required arguments.")
prompt_lines.append(
"You must ensure the JSON is valid and matches these definitions.\n"
"You must ensure the JSON is valid and matches these definitions."
)
prompt_lines.append("")
prompt_lines.append("== Tools Definitions ==")
for tool in tools_data.tools:
@@ -97,34 +113,42 @@ def generate_json_validation_prompt_from_tools_data(
prompt_lines.append("Validation checks:")
prompt_lines.append("1. Is the JSON syntactically valid? If not, fix it.")
prompt_lines.append(
"2. Does the 'tool' field match one of the tools in Tools Definitions? If not, correct or note the mismatch."
"2. Does the 'tool' field match one of the tools above (or 'none')?"
)
prompt_lines.append(
"3. Do the arguments under 'args' correspond exactly to the required arguments for that tool? Are they present and valid? If not, set them to null or correct them."
"3. Do the 'args' correspond exactly to the required arguments for that tool? "
"If arguments are missing, set them to null or correct them if possible."
)
prompt_lines.append(
"4. Confirm the 'response' and 'next' fields are present, if applicable, per the desired JSON structure."
"4. Check the 'response' field is present. The user-facing text can be corrected but not removed."
)
prompt_lines.append(
"5. If something is missing or incorrect, fix it in the final JSON output or explain what is missing."
"5. 'next' should be one of 'question', 'confirm', or 'done' (if no more actions)."
"Do NOT use 'next': 'confirm' until you have all args. If there are any args that are null then next='question'). "
)
prompt_lines.append(
"6. You can and should take values from the response, parse them and insert them into JSON args where possible. Carefully parse the history and the latest response to fill in the args."
"Use the conversation history to parse known data for filling 'args' if possible. "
)
prompt_lines.append("")
prompt_lines.append(
"Return your response in valid JSON. DO NOT RETURN ANYTHING EXCEPT VALID JSON IN THE CORRECT FORMAT. No editorializing or comments on the JSON."
)
prompt_lines.append("The final output must:")
prompt_lines.append(
'- Provide the corrected JSON if you can fix it, using the format {"response": "...", "next": "...", "tool": "...", "args": {...}}.'
"Return only valid JSON in the format:\n"
"{\n"
' "response": "...",\n'
' "next": "question|confirm|done",\n'
' "tool": "<existing-tool-name-or-none>",\n'
' "args": { ... }\n'
"}"
)
prompt_lines.append(
'- If you cannot correct it then provide a skeleton JSON structure with the original "response" value inside.\n'
"No additional commentary or explanation. Just the corrected JSON. "
)
prompt_lines.append("")
prompt_lines.append("Conversation history so far:")
prompt_lines.append(conversation_history)
prompt_lines.append("Begin validating now.")
prompt_lines.append(
"\nIMPORTANT: ANALYZE THIS HISTORY TO DETERMINE WHICH ARGUMENTS TO PRE-FILL IN THE JSON RESPONSE. "
)
prompt_lines.append("")
prompt_lines.append("Begin validating now. ")
return "\n".join(prompt_lines)

View File

@@ -1,6 +1,7 @@
# send_message.py
import asyncio
import sys
from typing import List
from temporalio.client import Client
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
@@ -8,8 +9,26 @@ from models.tool_definitions import ToolDefinition, ToolArgument
from workflows.tool_workflow import ToolWorkflow
async def main(prompt):
# Construct your tool definitions in code
async def main(prompt: str):
# 1) Define the FindEvents tool
find_events_tool = ToolDefinition(
name="FindEvents",
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
arguments=[
ToolArgument(
name="continent",
type="string",
description="Which continent or region to search for events",
),
ToolArgument(
name="month",
type="string",
description="The month or approximate date range to find events",
),
],
)
# 2) Define the SearchFlights tool
search_flights_tool = ToolDefinition(
name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
@@ -37,25 +56,51 @@ async def main(prompt):
],
)
# Wrap it in ToolsData
tools_data = ToolsData(tools=[search_flights_tool])
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None), tools_data=tools_data
# 3) Define the CreateInvoice tool
create_invoice_tool = ToolDefinition(
name="CreateInvoice",
description="Generate an invoice with flight information or other items to purchase",
arguments=[
ToolArgument(
name="amount",
type="float",
description="The total cost to be invoiced",
),
ToolArgument(
name="flightDetails",
type="string",
description="A summary of the flights, e.g., flight numbers, price breakdown",
),
],
)
# Create client connected to Temporal server
# Collect all tools in a ToolsData structure
all_tools: List[ToolDefinition] = [
find_events_tool,
search_flights_tool,
create_invoice_tool,
]
tools_data = ToolsData(tools=all_tools)
# Create the combined input (includes ToolsData + optional conversation summary or prompt queue)
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None),
tools_data=tools_data,
)
# 4) Connect to Temporal and start or signal the workflow
client = await Client.connect("localhost:7233")
workflow_id = "ollama-agent"
# Start or signal the workflow, passing OllamaParams and tools_data
# Note that we start the ToolWorkflow.run with 'combined_input'
# Then we immediately signal with the initial prompt
await client.start_workflow(
ToolWorkflow.run,
combined_input, # or pass custom summary/prompt_queue
combined_input,
id=workflow_id,
task_queue="ollama-task-queue",
start_signal="user_prompt",
start_signal="user_prompt", # This will send your first prompt to the workflow
start_signal_args=[prompt],
)
@@ -63,6 +108,9 @@ async def main(prompt):
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python send_message.py '<prompt>'")
print("Example: python send_message.py 'What animals are marsupials?'")
print(
"Example: python send_message.py 'I want an event in Oceania this March'"
" or 'Search flights from Seattle to San Francisco'"
)
else:
asyncio.run(main(sys.argv[1]))

View File

@@ -1,13 +1,14 @@
from .find_events import find_events
from .search_flights import search_flights
from .create_invoice import create_invoice
def get_handler(tool_name: str):
    """
    Return the function implementing the given tool.

    You can add more tools here, e.g. "BookHotel", etc.

    Args:
        tool_name: Name of the tool, as sent by the dynamic activity
            (matches the Temporal activity type, e.g. "FindEvents").

    Returns:
        The callable that takes a single ``dict`` of arguments.

    Raises:
        ValueError: If ``tool_name`` does not match any registered tool.
    """
    if tool_name == "FindEvents":
        return find_events
    if tool_name == "SearchFlights":
        return search_flights
    if tool_name == "CreateInvoice":
        return create_invoice
    # Diff residue had two raise statements here; the second was unreachable.
    # Keep only the committed message.
    raise ValueError(f"Unknown tool: {tool_name}")

7
tools/create_invoice.py Normal file
View File

@@ -0,0 +1,7 @@
def create_invoice(args: dict) -> dict:
    """Stub tool: pretend to generate an invoice for the given request.

    Args:
        args: Invoice details from the planner — e.g. "amount",
            "flightDetails". Currently only logged, not used.

    Returns:
        A fixed fake invoice payload with status and a payment URL.
    """
    # e.g. amount, flight details, etc.
    print("[CreateInvoice] Creating invoice with:", args)
    invoice = {
        "invoiceStatus": "generated",
        "invoiceURL": "https://pay.example.com/invoice/12345",
    }
    return invoice

17
tools/find_events.py Normal file
View File

@@ -0,0 +1,17 @@
def find_events(args: dict) -> dict:
    """Stub tool: return a canned event listing for the requested region/month.

    Args:
        args: Expected keys are "continent" (e.g. "Oceania") and "month"
            (e.g. "April"); missing keys fall back to None.

    Returns:
        A dict with an "eventsFound" list and a "status" flag.
    """
    # Example: continent="Oceania", month="April"
    continent = args.get("continent")
    month = args.get("month")
    print(f"[FindEvents] Searching events in {continent} for {month} ...")

    # Stub result
    melbourne_festival = {
        "city": "Melbourne",
        "eventName": "Melbourne International Comedy Festival",
        "dates": "2025-03-26 to 2025-04-20",
    }
    return {"eventsFound": [melbourne_festival], "status": "found-events"}

View File

@@ -25,7 +25,6 @@ class ToolWorkflow:
@workflow.run
async def run(self, combined_input: CombinedInput) -> str:
params = combined_input.tool_params
tools_data = combined_input.tools_data
@@ -39,14 +38,36 @@ class ToolWorkflow:
self.prompt_queue.extend(params.prompt_queue)
while True:
# 1) Wait for a user prompt or an end-chat
await workflow.wait_condition(
lambda: bool(self.prompt_queue) or self.chat_ended
)
if self.prompt_queue:
# 1) Get the user prompt -> call initial LLM
if self.chat_ended:
# Possibly do a summary if multiple turns
if len(self.conversation_history) > 1:
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput(
prompt=summary_prompt,
context_instructions=summary_context,
)
self.conversation_summary = await workflow.start_activity_method(
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
workflow.logger.info(
"Chat ended. Conversation summary:\n"
+ f"{self.conversation_summary}"
)
return f"{self.conversation_history}"
# 2) Pop the users new message from the queue
prompt = self.prompt_queue.popleft()
self.conversation_history.append(("user", prompt))
# 3) Call the LLM with the entire conversation + Tools
context_instructions = generate_genai_prompt_from_tools_data(
tools_data, self.format_history()
)
@@ -60,7 +81,7 @@ class ToolWorkflow:
schedule_to_close_timeout=timedelta(seconds=20),
)
# 2) Validate + parse in one shot
# 4) Validate + parse in one shot
tool_data = await workflow.execute_activity_method(
ToolActivities.validate_and_parse_json,
args=[responsePrechecked, tools_data, self.format_history()],
@@ -68,37 +89,57 @@ class ToolWorkflow:
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)),
)
# store it
# 5) Store it and show the conversation
self.tool_data = tool_data
self.conversation_history.append(("response", str(tool_data)))
if self.tool_data.get("next") == "confirm":
# 6) Check for special flags
next_step = self.tool_data.get("next") # e.g. "confirm", "question", "done"
current_tool = self.tool_data.get(
"tool"
) # e.g. "FindEvents", "SearchFlights", "CreateInvoice"
if next_step == "confirm" and current_tool:
# We have enough info to call the tool
dynamic_result = await workflow.execute_activity(
self.tool_data["tool"], # dynamic activity name
self.tool_data["args"], # single argument to pass
current_tool,
self.tool_data["args"], # single argument
schedule_to_close_timeout=timedelta(seconds=20),
)
return dynamic_result
# Append tools result to the conversation
self.conversation_history.append(
(f"{current_tool}_result", str(dynamic_result))
)
# Continue as new after X turns
# Enqueue a follow-up question to the LLM
self.prompt_queue.append(
f"The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, and the context_instructions (conversation history) to intelligently pre-fill the next tool's arguments. "
"What should we do next? "
)
# The loop continues, and on the next iteration, the workflow sees that new "prompt"
# as if the user typed it, calls the LLM, etc.
elif next_step == "done":
# LLM signals no more tools needed
workflow.logger.info("All steps completed. Exiting workflow.")
return str(self.conversation_history)
# 7) Optionally handle "continue_as_new" after many turns
if len(self.conversation_history) >= self.max_turns_before_continue:
# Summarize conversation
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput(
prompt=summary_prompt,
context_instructions=summary_context,
)
self.conversation_summary = await workflow.start_activity_method(
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
workflow.logger.info(
"Continuing as new after %i turns."
% self.max_turns_before_continue,
f"Continuing as new after {self.max_turns_before_continue} turns."
)
workflow.continue_as_new(
@@ -113,6 +154,9 @@ class ToolWorkflow:
]
)
# 8) If "next_step" is "question" or anything else,
# we just keep looping, waiting for user prompt or signals.
continue
# Handle end of chat