diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 28a8ecf..cc3f753 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -102,16 +102,16 @@ def get_current_date_human_readable(): @activity.defn(dynamic=True) def dynamic_tool_activity(args: Sequence[RawValue]) -> dict: - """Invoked for an unknown activity type, delegates to the correct tool.""" - from tools import get_handler # import the registry function + from tools import get_handler - tool_name = activity.info().activity_type # e.g. "SearchFlights" + tool_name = activity.info().activity_type # e.g. "FindEvents" tool_args = activity.payload_converter().from_payload(args[0].payload, dict) + activity.logger.info(f"Running dynamic tool '{tool_name}' with args: {tool_args}") - activity.logger.info(f"Dynamic activity triggered for tool: {tool_name}") - handler_func = get_handler(tool_name) - - # Delegate to the tool's function - result = handler_func(tool_args) + # Delegate to the relevant function + handler = get_handler(tool_name) + result = handler(tool_args) + # Optionally log or augment the result + activity.logger.info(f"Tool '{tool_name}' result: {result}") return result diff --git a/models/tool_definitions.py b/models/tool_definitions.py index 79fbfa6..739b77c 100644 --- a/models/tool_definitions.py +++ b/models/tool_definitions.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass -from typing import List +from dataclasses import dataclass, field +from typing import List, Dict, Any @dataclass @@ -19,3 +19,14 @@ class ToolDefinition: @dataclass class ToolsData: tools: List[ToolDefinition] + + +@dataclass +class ToolInvocation: + tool: str + args: Dict[str, Any] + + +@dataclass +class MultiToolSequence: + tool_invocations: List[ToolInvocation] = field(default_factory=list) diff --git a/prompts/agent_prompt_generators.py b/prompts/agent_prompt_generators.py index 013cc85..7585e43 100644 --- a/prompts/agent_prompt_generators.py +++ 
b/prompts/agent_prompt_generators.py @@ -6,54 +6,68 @@ def generate_genai_prompt_from_tools_data( ) -> str: """ Generates a prompt describing the tools and the instructions for the AI - assistant, using the conversation history provided. + assistant, using the conversation history provided, allowing for multiple + tools and a 'done' state. """ prompt_lines = [] prompt_lines.append( - "You are an AI assistant that must determine all required arguments" + "You are an AI assistant that must determine all required arguments " + "for the tools to achieve the user's goal. " + ) + prompt_lines.append("") + prompt_lines.append( + "Conversation history so far. \nANALYZE THIS HISTORY TO DETERMINE WHICH ARGUMENTS TO PRE-FILL AS SPECIFIED FOR THE TOOL BELOW: " ) - prompt_lines.append("for the tools to achieve the user's goal.\n") - prompt_lines.append("Conversation history so far:") prompt_lines.append(conversation_history) prompt_lines.append("") # List all tools and their arguments + prompt_lines.append("Available tools and their required arguments:") for tool in tools_data.tools: - prompt_lines.append(f"Tool to run: {tool.name}") - prompt_lines.append(f"Description: {tool.description}") - prompt_lines.append("Arguments needed:") + prompt_lines.append(f"- Tool name: {tool.name}") + prompt_lines.append(f" Description: {tool.description}") + prompt_lines.append(" Arguments needed:") for arg in tool.arguments: - prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}") + prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}") prompt_lines.append("") prompt_lines.append("Instructions:") prompt_lines.append( - "1. You need to ask the user (or confirm with them) for each argument required by the tools above." + "1. You may call multiple tools in sequence if needed, each requiring certain arguments. " + "Ask the user for missing details when necessary. " ) prompt_lines.append( - "2. If you do not yet have a specific argument value, ask the user for it." 
+ "2. If you do not yet have a specific argument value, ask the user for it by setting 'next': 'question'." ) prompt_lines.append( - "3. Once you have all arguments, read them back to confirm with the user before yielding to the tool to take action.\n" + "3. Once you have enough information for a particular tool, respond with 'next': 'confirm' and include the tool name in 'tool'." ) prompt_lines.append( - 'Your response must be valid JSON in the format: {"response": "", "next": "", ' - + '"tool": "", "arg1": "value1", "arg2": "value2"}" where args are the arguments for the tool (or null if unknown so far)."' + "4. If you have completed all necessary tools (no more actions needed), use 'next': 'done' in your JSON response." ) prompt_lines.append( - '- Your goal is to convert the AI responses into filled args in the JSON and once all args are filled, confirm with the user.".' + "5. Your response must be valid JSON in this format:\n" + " {\n" + ' "response": "",\n' + ' "next": "",\n' + ' "tool": "",\n' + ' "args": {\n' + ' "": "",\n' + ' "": "", ...\n' + " }\n" + " }\n" + " where 'args' are the arguments for the tool (or empty if not needed)." ) prompt_lines.append( - '- If you still need information from the user, use "next": "question".' + "6. If you still need information from the user, use 'next': 'question'. " + "If you have enough info for a specific tool, use 'next': 'confirm'. " + "Do NOT use 'next': 'confirm' until you have all necessary arguments (i.e. they're NOT 'null'). " + "If you are finished with all tools, use 'next': 'done'." ) prompt_lines.append( - '- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.' + "7. Keep responses in plain text. Return valid JSON without extra commentary." ) - prompt_lines.append( - '- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from to from to . 
Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateDepart": "2025-01-04", "dateReturn": "2025-01-08"}, "next": "confirm", "tool": "" }' - ) - prompt_lines.append("- Return valid JSON without special characters.") prompt_lines.append("") prompt_lines.append("Begin by prompting or confirming the necessary details.") @@ -66,9 +80,10 @@ def generate_json_validation_prompt_from_tools_data( """ Generates a prompt instructing the AI to: 1. Check that the given raw JSON is syntactically valid. - 2. Ensure the 'tool' matches one of the defined tools in tools_data. - 3. Confirm or correct that all required arguments are present and make sense. + 2. Ensure the 'tool' matches one of the defined tools or is 'none' if no tool is needed. + 3. Confirm or correct that all required arguments are present or set to null if missing. 4. Return a corrected JSON if possible. + 5. Accept 'next' as one of 'question', 'confirm', or 'done'. """ prompt_lines = [] @@ -78,8 +93,9 @@ def generate_json_validation_prompt_from_tools_data( prompt_lines.append("It may be malformed or incomplete.") prompt_lines.append("You also have a list of tools and their required arguments.") prompt_lines.append( - "You must ensure the JSON is valid and matches these definitions.\n" + "You must ensure the JSON is valid and matches these definitions." ) + prompt_lines.append("") prompt_lines.append("== Tools Definitions ==") for tool in tools_data.tools: @@ -97,34 +113,42 @@ def generate_json_validation_prompt_from_tools_data( prompt_lines.append("Validation checks:") prompt_lines.append("1. Is the JSON syntactically valid? If not, fix it.") prompt_lines.append( - "2. Does the 'tool' field match one of the tools in Tools Definitions? If not, correct or note the mismatch." + "2. Does the 'tool' field match one of the tools above (or 'none')?" ) prompt_lines.append( - "3. Do the arguments under 'args' correspond exactly to the required arguments for that tool? 
Are they present and valid? If not, set them to null or correct them." + "3. Do the 'args' correspond exactly to the required arguments for that tool? " + "If arguments are missing, set them to null or correct them if possible." ) prompt_lines.append( - "4. Confirm the 'response' and 'next' fields are present, if applicable, per the desired JSON structure." + "4. Check the 'response' field is present. The user-facing text can be corrected but not removed." ) prompt_lines.append( - "5. If something is missing or incorrect, fix it in the final JSON output or explain what is missing." + "5. 'next' should be one of 'question', 'confirm', or 'done' (if no more actions). " + "Do NOT use 'next': 'confirm' until you have all args. If there are any args that are null then next='question'. " ) prompt_lines.append( - "6. You can and should take values from the response, parse them and insert them into JSON args where possible. Carefully parse the history and the latest response to fill in the args." + "Use the conversation history to parse known data for filling 'args' if possible. " ) prompt_lines.append("") prompt_lines.append( - "Return your response in valid JSON. DO NOT RETURN ANYTHING EXCEPT VALID JSON IN THE CORRECT FORMAT. No editorializing or comments on the JSON." - ) - prompt_lines.append("The final output must:") - prompt_lines.append( - '- Provide the corrected JSON if you can fix it, using the format {"response": "...", "next": "...", "tool": "...", "args": {...}}.' + "Return only valid JSON in the format:\n" + "{\n" + ' "response": "...",\n' + ' "next": "question|confirm|done",\n' + ' "tool": "",\n' + ' "args": { ... }\n' + "}" ) prompt_lines.append( - '- If you cannot correct it then provide a skeleton JSON structure with the original "response" value inside.\n' + "No additional commentary or explanation. Just the corrected JSON. 
" ) + prompt_lines.append("") prompt_lines.append("Conversation history so far:") prompt_lines.append(conversation_history) - - prompt_lines.append("Begin validating now.") + prompt_lines.append( + "\nIMPORTANT: ANALYZE THIS HISTORY TO DETERMINE WHICH ARGUMENTS TO PRE-FILL IN THE JSON RESPONSE. " + ) + prompt_lines.append("") + prompt_lines.append("Begin validating now. ") return "\n".join(prompt_lines) diff --git a/scripts/send_message.py b/scripts/send_message.py index 41252c7..baaf8e9 100644 --- a/scripts/send_message.py +++ b/scripts/send_message.py @@ -1,6 +1,7 @@ +# send_message.py import asyncio import sys - +from typing import List from temporalio.client import Client from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams @@ -8,8 +9,26 @@ from models.tool_definitions import ToolDefinition, ToolArgument from workflows.tool_workflow import ToolWorkflow -async def main(prompt): - # Construct your tool definitions in code +async def main(prompt: str): + # 1) Define the FindEvents tool + find_events_tool = ToolDefinition( + name="FindEvents", + description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month", + arguments=[ + ToolArgument( + name="continent", + type="string", + description="Which continent or region to search for events", + ), + ToolArgument( + name="month", + type="string", + description="The month or approximate date range to find events", + ), + ], + ) + + # 2) Define the SearchFlights tool search_flights_tool = ToolDefinition( name="SearchFlights", description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)", @@ -37,25 +56,51 @@ async def main(prompt): ], ) - # Wrap it in ToolsData - tools_data = ToolsData(tools=[search_flights_tool]) - - combined_input = CombinedInput( - tool_params=ToolWorkflowParams(None, None), tools_data=tools_data + # 3) Define the CreateInvoice tool + create_invoice_tool = ToolDefinition( + 
name="CreateInvoice", + description="Generate an invoice with flight information or other items to purchase", + arguments=[ + ToolArgument( + name="amount", + type="float", + description="The total cost to be invoiced", + ), + ToolArgument( + name="flightDetails", + type="string", + description="A summary of the flights, e.g., flight numbers, price breakdown", + ), + ], ) - # Create client connected to Temporal server + # Collect all tools in a ToolsData structure + all_tools: List[ToolDefinition] = [ + find_events_tool, + search_flights_tool, + create_invoice_tool, + ] + tools_data = ToolsData(tools=all_tools) + + # Create the combined input (includes ToolsData + optional conversation summary or prompt queue) + combined_input = CombinedInput( + tool_params=ToolWorkflowParams(None, None), + tools_data=tools_data, + ) + + # 4) Connect to Temporal and start or signal the workflow client = await Client.connect("localhost:7233") workflow_id = "ollama-agent" - # Start or signal the workflow, passing OllamaParams and tools_data + # Note that we start the ToolWorkflow.run with 'combined_input' + # Then we immediately signal with the initial prompt await client.start_workflow( ToolWorkflow.run, - combined_input, # or pass custom summary/prompt_queue + combined_input, id=workflow_id, task_queue="ollama-task-queue", - start_signal="user_prompt", + start_signal="user_prompt", # This will send your first prompt to the workflow start_signal_args=[prompt], ) @@ -63,6 +108,9 @@ async def main(prompt): if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: python send_message.py ''") - print("Example: python send_message.py 'What animals are marsupials?'") + print( + "Example: python send_message.py 'I want an event in Oceania this March'" + " or 'Search flights from Seattle to San Francisco'" + ) else: asyncio.run(main(sys.argv[1])) diff --git a/tools/__init__.py b/tools/__init__.py index caf3c55..037e45d 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -1,13 
+1,14 @@ +from .find_events import find_events from .search_flights import search_flights +from .create_invoice import create_invoice def get_handler(tool_name: str): - """ - Return a function reference for the given tool. - You can add more tools here, e.g. "BookHotel", etc. - """ + if tool_name == "FindEvents": + return find_events if tool_name == "SearchFlights": return search_flights + if tool_name == "CreateInvoice": + return create_invoice - # Or raise if not recognized - raise ValueError(f"No handler found for tool '{tool_name}'") + raise ValueError(f"Unknown tool: {tool_name}") diff --git a/tools/create_invoice.py b/tools/create_invoice.py new file mode 100644 index 0000000..f0083d6 --- /dev/null +++ b/tools/create_invoice.py @@ -0,0 +1,7 @@ +def create_invoice(args: dict) -> dict: + # e.g. amount, flight details, etc. + print("[CreateInvoice] Creating invoice with:", args) + return { + "invoiceStatus": "generated", + "invoiceURL": "https://pay.example.com/invoice/12345", + } diff --git a/tools/find_events.py b/tools/find_events.py new file mode 100644 index 0000000..2d9d563 --- /dev/null +++ b/tools/find_events.py @@ -0,0 +1,17 @@ +def find_events(args: dict) -> dict: + # Example: continent="Oceania", month="April" + continent = args.get("continent") + month = args.get("month") + print(f"[FindEvents] Searching events in {continent} for {month} ...") + + # Stub result + return { + "eventsFound": [ + { + "city": "Melbourne", + "eventName": "Melbourne International Comedy Festival", + "dates": "2025-03-26 to 2025-04-20", + }, + ], + "status": "found-events", + } diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 1266ee2..ee7f7d2 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -25,7 +25,6 @@ class ToolWorkflow: @workflow.run async def run(self, combined_input: CombinedInput) -> str: - params = combined_input.tool_params tools_data = combined_input.tools_data @@ -39,79 +38,124 @@ class ToolWorkflow: 
self.prompt_queue.extend(params.prompt_queue) while True: + # 1) Wait for a user prompt or an end-chat await workflow.wait_condition( lambda: bool(self.prompt_queue) or self.chat_ended ) - if self.prompt_queue: - # 1) Get the user prompt -> call initial LLM - prompt = self.prompt_queue.popleft() - self.conversation_history.append(("user", prompt)) - context_instructions = generate_genai_prompt_from_tools_data( - tools_data, self.format_history() - ) - prompt_input = ToolPromptInput( - prompt=prompt, - context_instructions=context_instructions, - ) - responsePrechecked = await workflow.execute_activity_method( - ToolActivities.prompt_llm, - prompt_input, - schedule_to_close_timeout=timedelta(seconds=20), - ) - - # 2) Validate + parse in one shot - tool_data = await workflow.execute_activity_method( - ToolActivities.validate_and_parse_json, - args=[responsePrechecked, tools_data, self.format_history()], - schedule_to_close_timeout=timedelta(seconds=40), - retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)), - ) - - # store it - self.tool_data = tool_data - self.conversation_history.append(("response", str(tool_data))) - - if self.tool_data.get("next") == "confirm": - dynamic_result = await workflow.execute_activity( - self.tool_data["tool"], # dynamic activity name - self.tool_data["args"], # single argument to pass - schedule_to_close_timeout=timedelta(seconds=20), - ) - - return dynamic_result - - # Continue as new after X turns - if len(self.conversation_history) >= self.max_turns_before_continue: - # Summarize conversation + if self.chat_ended: + # Possibly do a summary if multiple turns + if len(self.conversation_history) > 1: summary_context, summary_prompt = self.prompt_summary_with_history() summary_input = ToolPromptInput( prompt=summary_prompt, context_instructions=summary_context, ) - self.conversation_summary = await workflow.start_activity_method( ToolActivities.prompt_llm, summary_input, schedule_to_close_timeout=timedelta(seconds=20), ) - 
workflow.logger.info( - "Continuing as new after %i turns." - % self.max_turns_before_continue, - ) + workflow.logger.info( + "Chat ended. Conversation summary:\n" + + f"{self.conversation_summary}" + ) + return f"{self.conversation_history}" - workflow.continue_as_new( - args=[ - CombinedInput( - tool_params=ToolWorkflowParams( - conversation_summary=self.conversation_summary, - prompt_queue=self.prompt_queue, - ), - tools_data=tools_data, - ) - ] - ) + # 2) Pop the user’s new message from the queue + prompt = self.prompt_queue.popleft() + self.conversation_history.append(("user", prompt)) + + # 3) Call the LLM with the entire conversation + Tools + context_instructions = generate_genai_prompt_from_tools_data( + tools_data, self.format_history() + ) + prompt_input = ToolPromptInput( + prompt=prompt, + context_instructions=context_instructions, + ) + responsePrechecked = await workflow.execute_activity_method( + ToolActivities.prompt_llm, + prompt_input, + schedule_to_close_timeout=timedelta(seconds=20), + ) + + # 4) Validate + parse in one shot + tool_data = await workflow.execute_activity_method( + ToolActivities.validate_and_parse_json, + args=[responsePrechecked, tools_data, self.format_history()], + schedule_to_close_timeout=timedelta(seconds=40), + retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)), + ) + + # 5) Store it and show the conversation + self.tool_data = tool_data + self.conversation_history.append(("response", str(tool_data))) + + # 6) Check for special flags + next_step = self.tool_data.get("next") # e.g. "confirm", "question", "done" + current_tool = self.tool_data.get( + "tool" + ) # e.g. 
"FindEvents", "SearchFlights", "CreateInvoice" + + if next_step == "confirm" and current_tool: + # We have enough info to call the tool + dynamic_result = await workflow.execute_activity( + current_tool, + self.tool_data["args"], # single argument + schedule_to_close_timeout=timedelta(seconds=20), + ) + + # Append tool’s result to the conversation + self.conversation_history.append( + (f"{current_tool}_result", str(dynamic_result)) + ) + + # Enqueue a follow-up question to the LLM + self.prompt_queue.append( + f"The '{current_tool}' tool completed successfully with {dynamic_result}. " + "INSTRUCTIONS: Use this tool result, and the context_instructions (conversation history) to intelligently pre-fill the next tool's arguments. " + "What should we do next? " + ) + # The loop continues, and on the next iteration, the workflow sees that new "prompt" + # as if the user typed it, calls the LLM, etc. + + elif next_step == "done": + # LLM signals no more tools needed + workflow.logger.info("All steps completed. Exiting workflow.") + return str(self.conversation_history) + + # 7) Optionally handle "continue_as_new" after many turns + if len(self.conversation_history) >= self.max_turns_before_continue: + summary_context, summary_prompt = self.prompt_summary_with_history() + summary_input = ToolPromptInput( + prompt=summary_prompt, + context_instructions=summary_context, + ) + self.conversation_summary = await workflow.start_activity_method( + ToolActivities.prompt_llm, + summary_input, + schedule_to_close_timeout=timedelta(seconds=20), + ) + workflow.logger.info( + f"Continuing as new after {self.max_turns_before_continue} turns." + ) + + workflow.continue_as_new( + args=[ + CombinedInput( + tool_params=ToolWorkflowParams( + conversation_summary=self.conversation_summary, + prompt_queue=self.prompt_queue, + ), + tools_data=tools_data, + ) + ] + ) + + # 8) If "next_step" is "question" or anything else, + # we just keep looping, waiting for user prompt or signals. continue