diff --git a/README.md b/README.md index 8cd44e4..c5d9b77 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case * Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env * It's free to sign up and get a key at [Stripe](https://stripe.com/) * If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file. +* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value. ## Configuring Temporal Connection @@ -137,3 +138,12 @@ Access the UI at `http://localhost:5173` - Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution. - Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output) - Tests would be nice! + +# TODO for this branch +## Agent +- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there. +- The use of `###` in prompts I want excluded from the conversation history is a bit of a hack. 
+ +## Validator function +- Probably keep data types, but move the activity and workflow code for the demo +- Probably don't need the validator function if it's the result from a tool call or confirmation step \ No newline at end of file diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 445791d..41441c7 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from temporalio import activity from ollama import chat, ChatResponse from openai import OpenAI @@ -11,6 +10,7 @@ import google.generativeai as genai import anthropic import deepseek from dotenv import load_dotenv +from models.data_types import ValidationInput, ValidationResult, ToolPromptInput load_dotenv(override=True) print( @@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama": print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME")) -@dataclass -class ToolPromptInput: - prompt: str - context_instructions: str - - class ToolActivities: + @activity.defn + async def validate_llm_prompt( + self, validation_input: ValidationInput + ) -> ValidationResult: + """ + Validates the prompt in the context of the conversation history and agent goal. + Returns a ValidationResult indicating if the prompt makes sense given the context. 
+ """ + # Create simple context string describing tools and goals + tools_description = [] + for tool in validation_input.agent_goal.tools: + tool_str = f"Tool: {tool.name}\n" + tool_str += f"Description: {tool.description}\n" + tool_str += "Arguments: " + ", ".join( + [f"{arg.name} ({arg.type})" for arg in tool.arguments] + ) + tools_description.append(tool_str) + tools_str = "\n".join(tools_description) + + # Convert conversation history to string + history_str = json.dumps(validation_input.conversation_history, indent=2) + + # Create context instructions + context_instructions = f"""The agent goal and tools are as follows: + Description: {validation_input.agent_goal.description} + Available Tools: + {tools_str} + The conversation history to date is: + {history_str}""" + + # Create validation prompt + validation_prompt = f"""The user's prompt is: "{validation_input.prompt}" + Please validate if this prompt makes sense given the agent goal and conversation history. + If the prompt doesn't make sense toward the goal then validationResult should be true. + Only return false if the prompt is nonsensical given the goal, tools available, and conversation history. + Return ONLY a JSON object with the following structure: + "validationResult": true/false, + "validationFailedReason": "If validationResult is false, provide a clear explanation to the user + about why their request doesn't make sense in the context and what information they should provide instead. 
+ validationFailedReason should contain JSON in the format + {{ + "next": "question", + "response": "[your reason here and a response to get the user back on track with the agent goal]" + }} + If validationResult is true, return an empty dict {{}}" + """ + + # Call the LLM with the validation prompt + prompt_input = ToolPromptInput( + prompt=validation_prompt, context_instructions=context_instructions + ) + + result = self.prompt_llm(prompt_input) + + return ValidationResult( + validationResult=result.get("validationResult", False), + validationFailedReason=result.get("validationFailedReason", {}), + ) + @activity.defn def prompt_llm(self, input: ToolPromptInput) -> dict: llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower() diff --git a/api/main.py b/api/main.py index 252b714..9dfff06 100644 --- a/api/main.py +++ b/api/main.py @@ -132,3 +132,26 @@ async def end_chat(): print(e) # Workflow not found; return an empty response return {} + + +@app.post("/start-workflow") +async def start_workflow(): + # Create combined input + combined_input = CombinedInput( + tool_params=ToolWorkflowParams(None, None), + agent_goal=goal_event_flight_invoice, + ) + + workflow_id = "agent-workflow" + + # Start the workflow with the starter prompt from the goal + await temporal_client.start_workflow( + ToolWorkflow.run, + combined_input, + id=workflow_id, + task_queue=TEMPORAL_TASK_QUEUE, + start_signal="user_prompt", + start_signal_args=["### " + goal_event_flight_invoice.starter_prompt], + ) + + return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."} diff --git a/frontend/src/pages/App.jsx b/frontend/src/pages/App.jsx index 94ef16a..4fde2ae 100644 --- a/frontend/src/pages/App.jsx +++ b/frontend/src/pages/App.jsx @@ -167,7 +167,7 @@ export default function App() { try { setError(INITIAL_ERROR_STATE); setLoading(true); - await apiService.sendMessage("I'd like to travel for an event."); + await apiService.startWorkflow(); 
setConversation([]); setLastMessage(null); } catch (err) { diff --git a/frontend/src/services/api.js b/frontend/src/services/api.js index af93d49..02bdd73 100644 --- a/frontend/src/services/api.js +++ b/frontend/src/services/api.js @@ -56,6 +56,26 @@ export const apiService = { } }, + async startWorkflow() { + try { + const res = await fetch( + `${API_BASE_URL}/start-workflow`, + { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + } + } + ); + return handleResponse(res); + } catch (error) { + throw new ApiError( + 'Failed to start workflow', + error.status || 500 + ); + } + }, + async confirm() { try { const res = await fetch(`${API_BASE_URL}/confirm`, { diff --git a/models/data_types.py b/models/data_types.py index a2de494..abe8799 100644 --- a/models/data_types.py +++ b/models/data_types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, Deque +from typing import Optional, Deque, Dict, Any, List, Union, Literal from models.tool_definitions import AgentGoal @@ -13,3 +13,32 @@ class ToolWorkflowParams: class CombinedInput: tool_params: ToolWorkflowParams agent_goal: AgentGoal + + +Message = Dict[str, Union[str, Dict[str, Any]]] +ConversationHistory = Dict[str, List[Message]] +NextStep = Literal["confirm", "question", "done"] + + +@dataclass +class ToolPromptInput: + prompt: str + context_instructions: str + + +@dataclass +class ValidationInput: + prompt: str + conversation_history: ConversationHistory + agent_goal: AgentGoal + + +@dataclass +class ValidationResult: + validationResult: bool + validationFailedReason: dict = None + + def __post_init__(self): + # Initialize empty dict if None + if self.validationFailedReason is None: + self.validationFailedReason = {} diff --git a/models/tool_definitions.py b/models/tool_definitions.py index 6e711b5..d4b085f 100644 --- a/models/tool_definitions.py +++ b/models/tool_definitions.py @@ -20,6 +20,7 @@ class ToolDefinition: class AgentGoal: tools: List[ToolDefinition] 
description: str = "Description of the tools purpose and overall goal" + starter_prompt: str = "Initial prompt to start the conversation" example_conversation_history: str = ( "Example conversation history to help the AI agent understand the context of the conversation" ) diff --git a/scripts/run_worker.py b/scripts/run_worker.py index 0f9bd5a..d83f547 100644 --- a/scripts/run_worker.py +++ b/scripts/run_worker.py @@ -24,6 +24,7 @@ async def main(): workflows=[ToolWorkflow], activities=[ activities.prompt_llm, + activities.validate_llm_prompt, dynamic_tool_activity, ], activity_executor=activity_executor, diff --git a/tools/goal_registry.py b/tools/goal_registry.py index 51d117a..983afcc 100644 --- a/tools/goal_registry.py +++ b/tools/goal_registry.py @@ -11,6 +11,7 @@ goal_event_flight_invoice = AgentGoal( "1. FindFixtures: Find fixtures for a team in a given month " "2. SearchFlights: search for a flight around the event dates " "3. CreateInvoice: Create a simple invoice for the cost of that flight ", + starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job", example_conversation_history="\n ".join( [ "user: I'd like to travel to a football match", diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 5f61149..d321df3 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -1,24 +1,21 @@ from collections import deque from datetime import timedelta -from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal +from typing import Dict, Any, Union, List, Optional, Deque, TypedDict from temporalio.common import RetryPolicy from temporalio import workflow +from models.data_types import ConversationHistory, NextStep, ValidationInput + with workflow.unsafe.imports_passed_through(): - from activities.tool_activities import ToolActivities, ToolPromptInput + from activities.tool_activities import ToolActivities from prompts.agent_prompt_generators 
import generate_genai_prompt - from models.data_types import CombinedInput, ToolWorkflowParams + from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput # Constants MAX_TURNS_BEFORE_CONTINUE = 250 TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20) -LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60) - -# Type definitions -Message = Dict[str, Union[str, Dict[str, Any]]] -ConversationHistory = Dict[str, List[Message]] -NextStep = Literal["confirm", "question", "done"] +LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30) class ToolData(TypedDict, total=False): @@ -153,6 +150,26 @@ class ToolWorkflow: prompt = self.prompt_queue.popleft() if not prompt.startswith("###"): self.add_message("user", prompt) + + # Validate the prompt before proceeding + validation_input = ValidationInput( + prompt=prompt, + conversation_history=self.conversation_history, + agent_goal=agent_goal, + ) + validation_result = await workflow.execute_activity( + ToolActivities.validate_llm_prompt, + args=[validation_input], + schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT, + retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)), + ) + + if not validation_result.validationResult: + # Handle validation failure + self.add_message("agent", validation_result.validationFailedReason) + continue # Skip to the next iteration + + # Proceed with generating the context and prompt context_instructions = generate_genai_prompt( agent_goal, self.conversation_history, self.tool_data