Merge pull request #5 from steveandroulakis/validator-and-improvements

Validator and improvements
This commit is contained in:
Steve Androulakis
2025-02-03 13:35:02 -08:00
committed by GitHub
10 changed files with 173 additions and 18 deletions

View File

@@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value.
## Configuring Temporal Connection
@@ -137,3 +138,12 @@ Access the UI at `http://localhost:5173`
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
- Tests would be nice!
# TODO for this branch
## Agent
- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there.
- Using a `###` prefix to mark prompts that should be excluded from the conversation history is a bit of a hack.
## Validator function
- Probably keep data types, but move the activity and workflow code for the demo
- Probably don't need the validator function if it's the result from a tool call or confirmation step

View File

@@ -1,4 +1,3 @@
from dataclasses import dataclass
from temporalio import activity
from ollama import chat, ChatResponse
from openai import OpenAI
@@ -11,6 +10,7 @@ import google.generativeai as genai
import anthropic
import deepseek
from dotenv import load_dotenv
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
load_dotenv(override=True)
print(
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
@dataclass
class ToolPromptInput:
    """Input bundle for a single LLM call.

    NOTE(review): duplicates ToolPromptInput in models/data_types.py, which is
    also imported at the top of this file — one copy should go away.
    """

    # The prompt text to send to the model.
    prompt: str
    # System/context instructions that accompany the prompt.
    context_instructions: str
class ToolActivities:
@activity.defn
async def validate_llm_prompt(
    self, validation_input: ValidationInput
) -> ValidationResult:
    """
    Validates the prompt in the context of the conversation history and agent goal.

    Builds a textual description of the goal's tools and the conversation so far,
    asks the LLM whether the user's prompt is sensible in that context, and wraps
    the LLM's JSON answer in a ValidationResult.

    :param validation_input: prompt, conversation history, and agent goal to check.
    :returns: ValidationResult — validationResult False only for nonsensical prompts,
        with validationFailedReason carrying the LLM's explanation (empty dict on success).
    """
    # Summarize each tool (name, description, argument names/types) for the LLM.
    tools_description = []
    for tool in validation_input.agent_goal.tools:
        tool_str = f"Tool: {tool.name}\n"
        tool_str += f"Description: {tool.description}\n"
        tool_str += "Arguments: " + ", ".join(
            f"{arg.name} ({arg.type})" for arg in tool.arguments
        )
        tools_description.append(tool_str)
    tools_str = "\n".join(tools_description)

    # Convert conversation history to string
    history_str = json.dumps(validation_input.conversation_history, indent=2)

    # Create context instructions
    context_instructions = f"""The agent goal and tools are as follows:
Description: {validation_input.agent_goal.description}
Available Tools:
{tools_str}
The conversation history to date is:
{history_str}"""

    # BUGFIX: the original instruction read "If the prompt doesn't make sense
    # toward the goal then validationResult should be true", contradicting the
    # sentence right after it ("Only return false if the prompt is nonsensical").
    # Corrected to "makes sense" so the two instructions agree.
    validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
Please validate if this prompt makes sense given the agent goal and conversation history.
If the prompt makes sense toward the goal then validationResult should be true.
Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
Return ONLY a JSON object with the following structure:
"validationResult": true/false,
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user
about why their request doesn't make sense in the context and what information they should provide instead.
validationFailedReason should contain JSON in the format
{{
"next": "question",
"response": "[your reason here and a response to get the user back on track with the agent goal]"
}}
If validationResult is true, return an empty dict {{}}"
"""

    # Reuse the regular LLM-prompting path (direct call, not a separate activity
    # invocation) to get the validation verdict as a dict.
    prompt_input = ToolPromptInput(
        prompt=validation_prompt, context_instructions=context_instructions
    )
    result = self.prompt_llm(prompt_input)

    # Default to "invalid, no reason" if the LLM omitted the expected keys.
    return ValidationResult(
        validationResult=result.get("validationResult", False),
        validationFailedReason=result.get("validationFailedReason", {}),
    )
@activity.defn
def prompt_llm(self, input: ToolPromptInput) -> dict:
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()

View File

@@ -132,3 +132,26 @@ async def end_chat():
print(e)
# Workflow not found; return an empty response
return {}
@app.post("/start-workflow")
async def start_workflow():
    """Start the agent workflow, seeding it with the goal's starter prompt."""
    # Bundle empty tool params together with the demo goal as the workflow input.
    workflow_input = CombinedInput(
        tool_params=ToolWorkflowParams(None, None),
        agent_goal=goal_event_flight_invoice,
    )

    starter_prompt = goal_event_flight_invoice.starter_prompt

    # The "### " prefix marks the prompt as internal so the workflow keeps it
    # out of the visible conversation history.
    await temporal_client.start_workflow(
        ToolWorkflow.run,
        workflow_input,
        id="agent-workflow",
        task_queue=TEMPORAL_TASK_QUEUE,
        start_signal="user_prompt",
        start_signal_args=["### " + starter_prompt],
    )

    return {
        "message": f"Workflow started with goal's starter prompt: {starter_prompt}."
    }

View File

@@ -167,7 +167,7 @@ export default function App() {
try {
setError(INITIAL_ERROR_STATE);
setLoading(true);
await apiService.sendMessage("I'd like to travel for an event.");
await apiService.startWorkflow();
setConversation([]);
setLastMessage(null);
} catch (err) {

View File

@@ -56,6 +56,26 @@ export const apiService = {
}
},
async startWorkflow() {
try {
const res = await fetch(
`${API_BASE_URL}/start-workflow`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json'
}
}
);
return handleResponse(res);
} catch (error) {
throw new ApiError(
'Failed to start workflow',
error.status || 500
);
}
},
async confirm() {
try {
const res = await fetch(`${API_BASE_URL}/confirm`, {

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Deque, Dict, Any, List, Union, Literal

from models.tool_definitions import AgentGoal
@@ -13,3 +13,32 @@ class ToolWorkflowParams:
class CombinedInput:
tool_params: ToolWorkflowParams
agent_goal: AgentGoal
# A single conversation entry: string fields plus optional structured payloads.
Message = Dict[str, Union[str, Dict[str, Any]]]
# Full history: named lists of messages (e.g. one list per conversation).
ConversationHistory = Dict[str, List[Message]]
# The agent's next action after an LLM turn.
NextStep = Literal["confirm", "question", "done"]
@dataclass
class ToolPromptInput:
    """A prompt and its accompanying context instructions for one LLM call."""

    # Prompt text sent to the model.
    prompt: str
    # Contextual/system instructions sent alongside the prompt.
    context_instructions: str
@dataclass
class ValidationInput:
    """Everything needed to validate a user prompt against the agent's goal."""

    # The user's latest prompt.
    prompt: str
    # Conversation so far, given to the LLM as context.
    conversation_history: ConversationHistory
    # The goal (description + tools) the prompt is validated against.
    agent_goal: AgentGoal
@dataclass
class ValidationResult:
    """Outcome of validating a user prompt against the agent goal.

    validationFailedReason is the LLM-provided JSON explaining a failure and is
    always a dict ({} when validation succeeded).
    """

    # True when the prompt makes sense for the goal and conversation.
    validationResult: bool
    # Idiomatic mutable default via default_factory instead of a None sentinel.
    validationFailedReason: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Still tolerate callers that pass None explicitly.
        if self.validationFailedReason is None:
            self.validationFailedReason = {}

View File

@@ -20,6 +20,7 @@ class ToolDefinition:
class AgentGoal:
    """Declarative description of an agent goal and the tools it may use."""

    # Tools the agent is allowed to invoke while pursuing this goal.
    tools: List[ToolDefinition]
    description: str = "Description of the tools purpose and overall goal"
    # Sent as the first (history-hidden, "### "-prefixed) prompt when the workflow starts.
    starter_prompt: str = "Initial prompt to start the conversation"
    # Few-shot example transcript given to the LLM as context.
    example_conversation_history: str = (
        "Example conversation history to help the AI agent understand the context of the conversation"
    )

View File

@@ -24,6 +24,7 @@ async def main():
workflows=[ToolWorkflow],
activities=[
activities.prompt_llm,
activities.validate_llm_prompt,
dynamic_tool_activity,
],
activity_executor=activity_executor,

View File

@@ -11,6 +11,7 @@ goal_event_flight_invoice = AgentGoal(
"1. FindFixtures: Find fixtures for a team in a given month "
"2. SearchFlights: search for a flight around the event dates "
"3. CreateInvoice: Create a simple invoice for the cost of that flight ",
starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job",
example_conversation_history="\n ".join(
[
"user: I'd like to travel to a football match",

View File

@@ -1,24 +1,21 @@
from collections import deque
from datetime import timedelta
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
from temporalio.common import RetryPolicy
from temporalio import workflow
from models.data_types import ConversationHistory, NextStep, ValidationInput
with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput
from activities.tool_activities import ToolActivities
from prompts.agent_prompt_generators import generate_genai_prompt
from models.data_types import CombinedInput, ToolWorkflowParams
from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput
# Constants
MAX_TURNS_BEFORE_CONTINUE = 250
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
# Type definitions
Message = Dict[str, Union[str, Dict[str, Any]]]
ConversationHistory = Dict[str, List[Message]]
NextStep = Literal["confirm", "question", "done"]
LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30)
class ToolData(TypedDict, total=False):
@@ -153,6 +150,26 @@ class ToolWorkflow:
prompt = self.prompt_queue.popleft()
if not prompt.startswith("###"):
self.add_message("user", prompt)
# Validate the prompt before proceeding
validation_input = ValidationInput(
prompt=prompt,
conversation_history=self.conversation_history,
agent_goal=agent_goal,
)
validation_result = await workflow.execute_activity(
ToolActivities.validate_llm_prompt,
args=[validation_input],
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)),
)
if not validation_result.validationResult:
# Handle validation failure
self.add_message("agent", validation_result.validationFailedReason)
continue # Skip to the next iteration
# Proceed with generating the context and prompt
context_instructions = generate_genai_prompt(
agent_goal, self.conversation_history, self.tool_data