diff --git a/README.md b/README.md index 8cd44e4..a1dc8ef 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case * Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env * It's free to sign up and get a key at [Stripe](https://stripe.com/) * If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file. +* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value. ## Configuring Temporal Connection @@ -137,3 +138,11 @@ Access the UI at `http://localhost:5173` - Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution. - Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output) - Tests would be nice! + +# TODO for this branch +## Agent +- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there. 
+ +## Validator function +- Probably keep data types, but move the activity and workflow code for the demo +- Probably don't need the validator function if it's the result from a tool call or confirmation step \ No newline at end of file diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 445791d..41441c7 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from temporalio import activity from ollama import chat, ChatResponse from openai import OpenAI @@ -11,6 +10,7 @@ import google.generativeai as genai import anthropic import deepseek from dotenv import load_dotenv +from models.data_types import ValidationInput, ValidationResult, ToolPromptInput load_dotenv(override=True) print( @@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama": print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME")) -@dataclass -class ToolPromptInput: - prompt: str - context_instructions: str - - class ToolActivities: + @activity.defn + async def validate_llm_prompt( + self, validation_input: ValidationInput + ) -> ValidationResult: + """ + Validates the prompt in the context of the conversation history and agent goal. + Returns a ValidationResult indicating if the prompt makes sense given the context. 
+ """ + # Create simple context string describing tools and goals + tools_description = [] + for tool in validation_input.agent_goal.tools: + tool_str = f"Tool: {tool.name}\n" + tool_str += f"Description: {tool.description}\n" + tool_str += "Arguments: " + ", ".join( + [f"{arg.name} ({arg.type})" for arg in tool.arguments] + ) + tools_description.append(tool_str) + tools_str = "\n".join(tools_description) + + # Convert conversation history to string + history_str = json.dumps(validation_input.conversation_history, indent=2) + + # Create context instructions + context_instructions = f"""The agent goal and tools are as follows: + Description: {validation_input.agent_goal.description} + Available Tools: + {tools_str} + The conversation history to date is: + {history_str}""" + + # Create validation prompt + validation_prompt = f"""The user's prompt is: "{validation_input.prompt}" + Please validate if this prompt makes sense given the agent goal and conversation history. + If the prompt doesn't make sense toward the goal then validationResult should be true. + Only return false if the prompt is nonsensical given the goal, tools available, and conversation history. + Return ONLY a JSON object with the following structure: + "validationResult": true/false, + "validationFailedReason": "If validationResult is false, provide a clear explanation to the user + about why their request doesn't make sense in the context and what information they should provide instead. 
+ validationFailedReason should contain JSON in the format + {{ + "next": "question", + "response": "[your reason here and a response to get the user back on track with the agent goal]" + }} + If validationResult is true, return an empty dict {{}}" + """ + + # Call the LLM with the validation prompt + prompt_input = ToolPromptInput( + prompt=validation_prompt, context_instructions=context_instructions + ) + + result = self.prompt_llm(prompt_input) + + return ValidationResult( + validationResult=result.get("validationResult", False), + validationFailedReason=result.get("validationFailedReason", {}), + ) + @activity.defn def prompt_llm(self, input: ToolPromptInput) -> dict: llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower() diff --git a/models/data_types.py b/models/data_types.py index a2de494..abe8799 100644 --- a/models/data_types.py +++ b/models/data_types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, Deque +from typing import Optional, Deque, Dict, Any, List, Union, Literal from models.tool_definitions import AgentGoal @@ -13,3 +13,32 @@ class ToolWorkflowParams: class CombinedInput: tool_params: ToolWorkflowParams agent_goal: AgentGoal + + +Message = Dict[str, Union[str, Dict[str, Any]]] +ConversationHistory = Dict[str, List[Message]] +NextStep = Literal["confirm", "question", "done"] + + +@dataclass +class ToolPromptInput: + prompt: str + context_instructions: str + + +@dataclass +class ValidationInput: + prompt: str + conversation_history: ConversationHistory + agent_goal: AgentGoal + + +@dataclass +class ValidationResult: + validationResult: bool + validationFailedReason: dict = None + + def __post_init__(self): + # Initialize empty dict if None + if self.validationFailedReason is None: + self.validationFailedReason = {} diff --git a/scripts/run_worker.py b/scripts/run_worker.py index 0f9bd5a..d83f547 100644 --- a/scripts/run_worker.py +++ b/scripts/run_worker.py @@ -24,6 +24,7 @@ async def main(): 
workflows=[ToolWorkflow], activities=[ activities.prompt_llm, + activities.validate_llm_prompt, dynamic_tool_activity, ], activity_executor=activity_executor, diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 5f61149..93369d8 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -1,24 +1,21 @@ from collections import deque from datetime import timedelta -from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal +from typing import Dict, Any, Union, List, Optional, Deque, TypedDict from temporalio.common import RetryPolicy from temporalio import workflow +from models.data_types import ConversationHistory, NextStep, ValidationInput + with workflow.unsafe.imports_passed_through(): - from activities.tool_activities import ToolActivities, ToolPromptInput + from activities.tool_activities import ToolActivities from prompts.agent_prompt_generators import generate_genai_prompt - from models.data_types import CombinedInput, ToolWorkflowParams + from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput # Constants MAX_TURNS_BEFORE_CONTINUE = 250 TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20) -LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60) - -# Type definitions -Message = Dict[str, Union[str, Dict[str, Any]]] -ConversationHistory = Dict[str, List[Message]] -NextStep = Literal["confirm", "question", "done"] +LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30) class ToolData(TypedDict, total=False): @@ -154,6 +151,26 @@ class ToolWorkflow: if not prompt.startswith("###"): self.add_message("user", prompt) + # Validate the prompt before proceeding + validation_input = ValidationInput( + prompt=prompt, + conversation_history=self.conversation_history, + agent_goal=agent_goal, + ) + validation_result = await workflow.execute_activity( + ToolActivities.validate_llm_prompt, + args=[validation_input], + schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT, + 
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)), + ) + + if not validation_result.validationResult: + # Handle validation failure + self.add_message("agent", validation_result.validationFailedReason) + continue # Skip to the next iteration + + # Proceed with generating the context and prompt + context_instructions = generate_genai_prompt( agent_goal, self.conversation_history, self.tool_data )