mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 05:58:08 +01:00
Merge pull request #5 from steveandroulakis/validator-and-improvements
Validator and improvements
This commit is contained in:
10
README.md
10
README.md
@@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case
|
|||||||
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
||||||
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
|
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
|
||||||
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
|
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
|
||||||
|
* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value.
|
||||||
|
|
||||||
## Configuring Temporal Connection
|
## Configuring Temporal Connection
|
||||||
|
|
||||||
@@ -137,3 +138,12 @@ Access the UI at `http://localhost:5173`
|
|||||||
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
|
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
|
||||||
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
|
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
|
||||||
- Tests would be nice!
|
- Tests would be nice!
|
||||||
|
|
||||||
|
# TODO for this branch
|
||||||
|
## Agent
|
||||||
|
- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there.
|
||||||
|
- The use of `###` in prompts I want excluded from the conversation history is a bit of a hack.
|
||||||
|
|
||||||
|
## Validator function
|
||||||
|
- Probably keep data types, but move the activity and workflow code for the demo
|
||||||
|
- Probably don't need the validator function if its the result from a tool call or confirmation step
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from temporalio import activity
|
from temporalio import activity
|
||||||
from ollama import chat, ChatResponse
|
from ollama import chat, ChatResponse
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -11,6 +10,7 @@ import google.generativeai as genai
|
|||||||
import anthropic
|
import anthropic
|
||||||
import deepseek
|
import deepseek
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
|
||||||
|
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
print(
|
print(
|
||||||
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
|
|||||||
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolPromptInput:
|
|
||||||
prompt: str
|
|
||||||
context_instructions: str
|
|
||||||
|
|
||||||
|
|
||||||
class ToolActivities:
|
class ToolActivities:
|
||||||
|
@activity.defn
|
||||||
|
async def validate_llm_prompt(
|
||||||
|
self, validation_input: ValidationInput
|
||||||
|
) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Validates the prompt in the context of the conversation history and agent goal.
|
||||||
|
Returns a ValidationResult indicating if the prompt makes sense given the context.
|
||||||
|
"""
|
||||||
|
# Create simple context string describing tools and goals
|
||||||
|
tools_description = []
|
||||||
|
for tool in validation_input.agent_goal.tools:
|
||||||
|
tool_str = f"Tool: {tool.name}\n"
|
||||||
|
tool_str += f"Description: {tool.description}\n"
|
||||||
|
tool_str += "Arguments: " + ", ".join(
|
||||||
|
[f"{arg.name} ({arg.type})" for arg in tool.arguments]
|
||||||
|
)
|
||||||
|
tools_description.append(tool_str)
|
||||||
|
tools_str = "\n".join(tools_description)
|
||||||
|
|
||||||
|
# Convert conversation history to string
|
||||||
|
history_str = json.dumps(validation_input.conversation_history, indent=2)
|
||||||
|
|
||||||
|
# Create context instructions
|
||||||
|
context_instructions = f"""The agent goal and tools are as follows:
|
||||||
|
Description: {validation_input.agent_goal.description}
|
||||||
|
Available Tools:
|
||||||
|
{tools_str}
|
||||||
|
The conversation history to date is:
|
||||||
|
{history_str}"""
|
||||||
|
|
||||||
|
# Create validation prompt
|
||||||
|
validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
|
||||||
|
Please validate if this prompt makes sense given the agent goal and conversation history.
|
||||||
|
If the prompt doesn't make sense toward the goal then validationResult should be true.
|
||||||
|
Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
|
||||||
|
Return ONLY a JSON object with the following structure:
|
||||||
|
"validationResult": true/false,
|
||||||
|
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user
|
||||||
|
about why their request doesn't make sense in the context and what information they should provide instead.
|
||||||
|
validationFailedReason should contain JSON in the format
|
||||||
|
{{
|
||||||
|
"next": "question",
|
||||||
|
"response": "[your reason here and a response to get the user back on track with the agent goal]"
|
||||||
|
}}
|
||||||
|
If validationResult is true, return an empty dict {{}}"
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Call the LLM with the validation prompt
|
||||||
|
prompt_input = ToolPromptInput(
|
||||||
|
prompt=validation_prompt, context_instructions=context_instructions
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.prompt_llm(prompt_input)
|
||||||
|
|
||||||
|
return ValidationResult(
|
||||||
|
validationResult=result.get("validationResult", False),
|
||||||
|
validationFailedReason=result.get("validationFailedReason", {}),
|
||||||
|
)
|
||||||
|
|
||||||
@activity.defn
|
@activity.defn
|
||||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||||
|
|||||||
23
api/main.py
23
api/main.py
@@ -132,3 +132,26 @@ async def end_chat():
|
|||||||
print(e)
|
print(e)
|
||||||
# Workflow not found; return an empty response
|
# Workflow not found; return an empty response
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/start-workflow")
|
||||||
|
async def start_workflow():
|
||||||
|
# Create combined input
|
||||||
|
combined_input = CombinedInput(
|
||||||
|
tool_params=ToolWorkflowParams(None, None),
|
||||||
|
agent_goal=goal_event_flight_invoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_id = "agent-workflow"
|
||||||
|
|
||||||
|
# Start the workflow with the starter prompt from the goal
|
||||||
|
await temporal_client.start_workflow(
|
||||||
|
ToolWorkflow.run,
|
||||||
|
combined_input,
|
||||||
|
id=workflow_id,
|
||||||
|
task_queue=TEMPORAL_TASK_QUEUE,
|
||||||
|
start_signal="user_prompt",
|
||||||
|
start_signal_args=["### " + goal_event_flight_invoice.starter_prompt],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."}
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ export default function App() {
|
|||||||
try {
|
try {
|
||||||
setError(INITIAL_ERROR_STATE);
|
setError(INITIAL_ERROR_STATE);
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
await apiService.sendMessage("I'd like to travel for an event.");
|
await apiService.startWorkflow();
|
||||||
setConversation([]);
|
setConversation([]);
|
||||||
setLastMessage(null);
|
setLastMessage(null);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
@@ -56,6 +56,26 @@ export const apiService = {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
async startWorkflow() {
|
||||||
|
try {
|
||||||
|
const res = await fetch(
|
||||||
|
`${API_BASE_URL}/start-workflow`,
|
||||||
|
{
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
return handleResponse(res);
|
||||||
|
} catch (error) {
|
||||||
|
throw new ApiError(
|
||||||
|
'Failed to start workflow',
|
||||||
|
error.status || 500
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
async confirm() {
|
async confirm() {
|
||||||
try {
|
try {
|
||||||
const res = await fetch(`${API_BASE_URL}/confirm`, {
|
const res = await fetch(`${API_BASE_URL}/confirm`, {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Deque
|
from typing import Optional, Deque, Dict, Any, List, Union, Literal
|
||||||
from models.tool_definitions import AgentGoal
|
from models.tool_definitions import AgentGoal
|
||||||
|
|
||||||
|
|
||||||
@@ -13,3 +13,32 @@ class ToolWorkflowParams:
|
|||||||
class CombinedInput:
|
class CombinedInput:
|
||||||
tool_params: ToolWorkflowParams
|
tool_params: ToolWorkflowParams
|
||||||
agent_goal: AgentGoal
|
agent_goal: AgentGoal
|
||||||
|
|
||||||
|
|
||||||
|
Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||||
|
ConversationHistory = Dict[str, List[Message]]
|
||||||
|
NextStep = Literal["confirm", "question", "done"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolPromptInput:
|
||||||
|
prompt: str
|
||||||
|
context_instructions: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationInput:
|
||||||
|
prompt: str
|
||||||
|
conversation_history: ConversationHistory
|
||||||
|
agent_goal: AgentGoal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
validationResult: bool
|
||||||
|
validationFailedReason: dict = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Initialize empty dict if None
|
||||||
|
if self.validationFailedReason is None:
|
||||||
|
self.validationFailedReason = {}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class ToolDefinition:
|
|||||||
class AgentGoal:
|
class AgentGoal:
|
||||||
tools: List[ToolDefinition]
|
tools: List[ToolDefinition]
|
||||||
description: str = "Description of the tools purpose and overall goal"
|
description: str = "Description of the tools purpose and overall goal"
|
||||||
|
starter_prompt: str = "Initial prompt to start the conversation"
|
||||||
example_conversation_history: str = (
|
example_conversation_history: str = (
|
||||||
"Example conversation history to help the AI agent understand the context of the conversation"
|
"Example conversation history to help the AI agent understand the context of the conversation"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ async def main():
|
|||||||
workflows=[ToolWorkflow],
|
workflows=[ToolWorkflow],
|
||||||
activities=[
|
activities=[
|
||||||
activities.prompt_llm,
|
activities.prompt_llm,
|
||||||
|
activities.validate_llm_prompt,
|
||||||
dynamic_tool_activity,
|
dynamic_tool_activity,
|
||||||
],
|
],
|
||||||
activity_executor=activity_executor,
|
activity_executor=activity_executor,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ goal_event_flight_invoice = AgentGoal(
|
|||||||
"1. FindFixtures: Find fixtures for a team in a given month "
|
"1. FindFixtures: Find fixtures for a team in a given month "
|
||||||
"2. SearchFlights: search for a flight around the event dates "
|
"2. SearchFlights: search for a flight around the event dates "
|
||||||
"3. CreateInvoice: Create a simple invoice for the cost of that flight ",
|
"3. CreateInvoice: Create a simple invoice for the cost of that flight ",
|
||||||
|
starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job",
|
||||||
example_conversation_history="\n ".join(
|
example_conversation_history="\n ".join(
|
||||||
[
|
[
|
||||||
"user: I'd like to travel to a football match",
|
"user: I'd like to travel to a football match",
|
||||||
|
|||||||
@@ -1,24 +1,21 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
|
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
|
||||||
|
|
||||||
from temporalio.common import RetryPolicy
|
from temporalio.common import RetryPolicy
|
||||||
from temporalio import workflow
|
from temporalio import workflow
|
||||||
|
|
||||||
|
from models.data_types import ConversationHistory, NextStep, ValidationInput
|
||||||
|
|
||||||
with workflow.unsafe.imports_passed_through():
|
with workflow.unsafe.imports_passed_through():
|
||||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
from activities.tool_activities import ToolActivities
|
||||||
from prompts.agent_prompt_generators import generate_genai_prompt
|
from prompts.agent_prompt_generators import generate_genai_prompt
|
||||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
MAX_TURNS_BEFORE_CONTINUE = 250
|
MAX_TURNS_BEFORE_CONTINUE = 250
|
||||||
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
||||||
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
|
LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30)
|
||||||
|
|
||||||
# Type definitions
|
|
||||||
Message = Dict[str, Union[str, Dict[str, Any]]]
|
|
||||||
ConversationHistory = Dict[str, List[Message]]
|
|
||||||
NextStep = Literal["confirm", "question", "done"]
|
|
||||||
|
|
||||||
|
|
||||||
class ToolData(TypedDict, total=False):
|
class ToolData(TypedDict, total=False):
|
||||||
@@ -153,6 +150,26 @@ class ToolWorkflow:
|
|||||||
prompt = self.prompt_queue.popleft()
|
prompt = self.prompt_queue.popleft()
|
||||||
if not prompt.startswith("###"):
|
if not prompt.startswith("###"):
|
||||||
self.add_message("user", prompt)
|
self.add_message("user", prompt)
|
||||||
|
|
||||||
|
# Validate the prompt before proceeding
|
||||||
|
validation_input = ValidationInput(
|
||||||
|
prompt=prompt,
|
||||||
|
conversation_history=self.conversation_history,
|
||||||
|
agent_goal=agent_goal,
|
||||||
|
)
|
||||||
|
validation_result = await workflow.execute_activity(
|
||||||
|
ToolActivities.validate_llm_prompt,
|
||||||
|
args=[validation_input],
|
||||||
|
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||||
|
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not validation_result.validationResult:
|
||||||
|
# Handle validation failure
|
||||||
|
self.add_message("agent", validation_result.validationFailedReason)
|
||||||
|
continue # Skip to the next iteration
|
||||||
|
|
||||||
|
# Proceed with generating the context and prompt
|
||||||
|
|
||||||
context_instructions = generate_genai_prompt(
|
context_instructions = generate_genai_prompt(
|
||||||
agent_goal, self.conversation_history, self.tool_data
|
agent_goal, self.conversation_history, self.tool_data
|
||||||
|
|||||||
Reference in New Issue
Block a user