fixes to issues 1 2 and 3. Plus tuning

This commit is contained in:
Steve Androulakis
2025-01-24 15:23:57 -08:00
parent caf5812f90
commit 7977894f64
8 changed files with 154 additions and 39 deletions

View File

@@ -5,6 +5,11 @@ RAPIDAPI_HOST=sky-scrapper.p.rapidapi.com
STRIPE_API_KEY=sk_test_51J... STRIPE_API_KEY=sk_test_51J...
LLM_PROVIDER=openai # default
# or
# LLM_PROVIDER=ollama
# OLLAMA_MODEL_NAME=qwen2.5:14b
# uncomment and unset these environment variables to connect to the local dev server # uncomment and unset these environment variables to connect to the local dev server
# TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233 # TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233
# TEMPORAL_NAMESPACE=default # TEMPORAL_NAMESPACE=default

View File

@@ -14,7 +14,28 @@ This application uses `.env` files for configuration. Copy the [.env.example](.e
cp .env.example .env cp .env.example .env
``` ```
The agent requires an OpenAI key for the gpt-4o model. Set this in the `OPENAI_API_KEY` environment variable in .env ### LLM Provider Configuration
The agent can use either OpenAI's GPT-4o or a local LLM via Ollama. Set the `LLM_PROVIDER` environment variable in your `.env` file to choose the desired provider:
- `LLM_PROVIDER=openai` for OpenAI's GPT-4o
- `LLM_PROVIDER=ollama` for the local LLM via Ollama (not recommended for this use case)
### OpenAI Configuration
If using OpenAI, ensure you have an OpenAI key for the GPT-4o model. Set this in the `OPENAI_API_KEY` environment variable in `.env`.
### Ollama Configuration
To use a local LLM with Ollama:
1. Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model.
- Run `ollama run <OLLAMA_MODEL_NAME>` to start the model. Note that the example model below is about a 9 GB download.
- Example: `ollama run qwen2.5:14b`
2. Set `LLM_PROVIDER=ollama` in your `.env` file and `OLLAMA_MODEL_NAME` to the name of the model you installed.
Note: The local LLM is disabled by default, as GPT-4o was found to be much more reliable for this use case. However, you can switch to Ollama if desired.
## Agent Tools ## Agent Tools
* Requires a RapidAPI key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env * Requires a RapidAPI key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env
@@ -85,12 +106,6 @@ Access the UI at `http://localhost:5173`
- Note the mapping in `tools/__init__.py` to each tool - Note the mapping in `tools/__init__.py` to each tool
- See main.py where some tool-specific logic is defined (todo, move this to the tool definition) - See main.py where some tool-specific logic is defined (todo, move this to the tool definition)
## Using a local LLM instead of ChatGPT 4o
With a small code change, the agent can use local LLMs.
* Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model (`ollama run qwen2.5:14b`). (note this model is about 9GB to download).
* Local LLM is disabled as ChatGPT 4o was better for this use case. To use Ollama, examine `./activities/tool_activities.py` and rename the existing functions.
* Note that Qwen2.5 14B is not as good as ChatGPT 4o for this use case and will perform worse at moving the conversation towards the goal.
## TODO ## TODO
- I should prove this out with other tool definitions outside of the event/flight search case (take advantage of my nice DSL). - I should prove this out with other tool definitions outside of the event/flight search case (take advantage of my nice DSL).
- Currently hardcoded to the Temporal dev server at localhost:7233. Need to support options incl Temporal Cloud. - Currently hardcoded to the Temporal dev server at localhost:7233. Need to support options incl Temporal Cloud.

View File

@@ -18,6 +18,14 @@ class ToolPromptInput:
class ToolActivities: class ToolActivities:
@activity.defn @activity.defn
def prompt_llm(self, input: ToolPromptInput) -> dict: def prompt_llm(self, input: ToolPromptInput) -> dict:
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
if llm_provider == "ollama":
return self.prompt_llm_ollama(input)
else:
return self.prompt_llm_openai(input)
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
client = OpenAI( client = OpenAI(
api_key=os.environ.get( api_key=os.environ.get(
"OPENAI_API_KEY" "OPENAI_API_KEY"
@@ -44,9 +52,8 @@ class ToolActivities:
response_content = chat_completion.choices[0].message.content response_content = chat_completion.choices[0].message.content
print(f"ChatGPT response: {response_content}") print(f"ChatGPT response: {response_content}")
# Trim formatting markers if present # Use the new sanitize function
if response_content.startswith("```json") and response_content.endswith("```"): response_content = self.sanitize_json_response(response_content)
response_content = response_content[7:-3].strip()
try: try:
data = json.loads(response_content) data = json.loads(response_content)
@@ -58,7 +65,7 @@ class ToolActivities:
@activity.defn @activity.defn
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict: def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
model_name = "qwen2.5:14b" model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
messages = [ messages = [
{ {
"role": "system", "role": "system",
@@ -76,8 +83,11 @@ class ToolActivities:
print(f"Chat response: {response.message.content}") print(f"Chat response: {response.message.content}")
# Use the new sanitize function
response_content = self.sanitize_json_response(response.message.content)
try: try:
data = json.loads(response.message.content) data = json.loads(response_content)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}") print(f"Invalid JSON: {e}")
print(response.message.content) print(response.message.content)
@@ -85,6 +95,54 @@ class ToolActivities:
return data return data
def sanitize_json_response(self, response_content: str) -> str:
    """
    Extract the JSON block from an LLM response and return it as a string.

    Supports:
        - JSON surrounded by ```json and ``` fences
        - Raw JSON input
        - JSON preceded or followed by extra text

    Args:
        response_content: Raw text returned by the LLM.

    Returns:
        The extracted JSON block as a string (validated to parse).

    Raises:
        ValueError: If the response does not contain valid JSON.
    """
    candidates = []

    # Case 1: JSON inside a fenced ```json ... ``` block.
    # NOTE: we search for the closing fence *after* the opening marker;
    # checking `"```" in response_content` alone is always true when the
    # opening "```json" marker is present, so a missing closing fence must
    # be handled explicitly rather than assumed.
    start_marker = "```json"
    end_marker = "```"
    fence_open = response_content.find(start_marker)
    if fence_open != -1:
        json_start = fence_open + len(start_marker)
        fence_close = response_content.find(end_marker, json_start)
        if fence_close != -1:
            candidates.append(response_content[json_start:fence_close].strip())

    # Case 2: locate a JSON object by scanning for the first `{` and last `}`.
    # Also serves as a fallback when the fenced content is absent or invalid.
    brace_start = response_content.find("{")
    brace_end = response_content.rfind("}")
    if brace_start != -1 and brace_end != -1 and brace_start < brace_end:
        candidates.append(response_content[brace_start : brace_end + 1].strip())

    # Return the first candidate that parses as valid JSON.
    for json_str in candidates:
        try:
            json.loads(json_str)  # raises JSONDecodeError if invalid
            return json_str
        except json.JSONDecodeError:
            continue

    # No candidate parsed — reject the response.
    print(f"Invalid JSON detected in response: {response_content}")
    raise ValueError("Response does not contain valid JSON.")
def get_current_date_human_readable(): def get_current_date_human_readable():
""" """

View File

@@ -1,14 +1,14 @@
from fastapi import FastAPI from fastapi import FastAPI
from typing import Optional from typing import Optional
from temporalio.client import Client from temporalio.client import Client
from temporalio.exceptions import TemporalError
from temporalio.api.enums.v1 import WorkflowExecutionStatus
from workflows.tool_workflow import ToolWorkflow from workflows.tool_workflow import ToolWorkflow
from models.data_types import CombinedInput, ToolWorkflowParams from models.data_types import CombinedInput, ToolWorkflowParams
from tools.goal_registry import goal_event_flight_invoice from tools.goal_registry import goal_event_flight_invoice
from temporalio.exceptions import TemporalError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
app = FastAPI() app = FastAPI()
temporal_client: Optional[Client] = None temporal_client: Optional[Client] = None
@@ -58,11 +58,32 @@ async def get_conversation_history():
"""Calls the workflow's 'get_conversation_history' query.""" """Calls the workflow's 'get_conversation_history' query."""
try: try:
handle = temporal_client.get_workflow_handle("agent-workflow") handle = temporal_client.get_workflow_handle("agent-workflow")
conversation_history = await handle.query("get_conversation_history")
status_names = {
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED"
}
failed_states = [
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED
]
# Check workflow status first
description = await handle.describe()
if description.status in failed_states:
status_name = status_names.get(description.status, "UNKNOWN_STATUS")
print(f"Workflow is in {status_name} state. Returning empty history.")
return []
# Only query if workflow is running
conversation_history = await handle.query("get_conversation_history")
return conversation_history return conversation_history
except TemporalError as e: except TemporalError as e:
print(e) print(f"Temporal error: {e}")
return [] return []

View File

@@ -21,7 +21,7 @@ class ChatErrorBoundary extends React.Component {
if (this.state.hasError) { if (this.state.hasError) {
return ( return (
<div className="text-red-500 p-4 text-center"> <div className="text-red-500 p-4 text-center">
Something went wrong. Please refresh the page. Something went wrong. Please Terminate the workflow and try again.
</div> </div>
); );
} }

View File

@@ -53,7 +53,7 @@ def find_events(args: dict) -> dict:
"dateFrom": event["dateFrom"], "dateFrom": event["dateFrom"],
"dateTo": event["dateTo"], "dateTo": event["dateTo"],
"description": event["description"], "description": event["description"],
"monthContext": month_context, "month": month_context,
} }
) )

View File

@@ -3,17 +3,19 @@ from models.tool_definitions import ToolDefinition, ToolArgument
find_events_tool = ToolDefinition( find_events_tool = ToolDefinition(
name="FindEvents", name="FindEvents",
description="Find upcoming events to travel to a given city (e.g., 'Melbourne') and a date or month. " description="Find upcoming events to travel to a given city (e.g., 'Melbourne') and a date or month. "
"It knows about events in Oceania only (e.g. major Australian and New Zealand cities).", "It knows about events in Oceania only (e.g. major Australian and New Zealand cities). "
"It will search 1 month either side of the month provided. "
"Returns a list of events. ",
arguments=[ arguments=[
ToolArgument( ToolArgument(
name="city", name="city",
type="string", type="string",
description="Which city to search for events", description="Which city to search for events",
), ),
ToolArgument( ToolArgument(
name="month", name="month",
type="string", type="string",
description="The month or approximate date range to find events", description="The month to search for events (will search 1 month either side of the month provided)",
), ),
], ],
) )

View File

@@ -20,12 +20,14 @@ Message = Dict[str, Union[str, Dict[str, Any]]]
ConversationHistory = Dict[str, List[Message]] ConversationHistory = Dict[str, List[Message]]
NextStep = Literal["confirm", "question", "done"] NextStep = Literal["confirm", "question", "done"]
class ToolData(TypedDict, total=False): class ToolData(TypedDict, total=False):
next: NextStep next: NextStep
tool: str tool: str
args: Dict[str, Any] args: Dict[str, Any]
response: str response: str
@workflow.defn @workflow.defn
class ToolWorkflow: class ToolWorkflow:
"""Workflow that manages tool execution with user confirmation and conversation history.""" """Workflow that manages tool execution with user confirmation and conversation history."""
@@ -39,7 +41,9 @@ class ToolWorkflow:
self.confirm: bool = False self.confirm: bool = False
self.tool_results: List[Dict[str, Any]] = [] self.tool_results: List[Dict[str, Any]] = []
async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None: async def _handle_tool_execution(
self, current_tool: str, tool_data: ToolData
) -> None:
"""Execute a tool after confirmation and handle its result.""" """Execute a tool after confirmation and handle its result."""
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
@@ -49,15 +53,23 @@ class ToolWorkflow:
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT, schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
) )
dynamic_result["tool"] = current_tool dynamic_result["tool"] = current_tool
self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result}) self.add_message(
"tool_result", {"tool": current_tool, "result": dynamic_result}
)
self.prompt_queue.append( self.prompt_queue.append(
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. " f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. " "INSTRUCTIONS: Parse this tool result as plain text, and use the system prompt containing the list of tools in sequence and the conversation history to figure out next steps, if any. "
"DON'T ask any clarifying questions that are outside of the tools and args specified. " '{"next": "<question|confirm|done>", "tool": "<tool_name or null>", "args": {"<arg1>": "<value1 or null>", "<arg2>": "<value2 or null>}, "response": "<plain text>"}'
"ONLY return those json keys (next, tool, args, response), nothing else."
'Next should only be "done" if all tools have been run (use the system prompt to figure that out).'
'Next should be "question" if the tool is not the last one in the sequence.'
'Next should NOT be "confirm" at this point.'
) )
async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool: async def _handle_missing_args(
self, current_tool: str, args: Dict[str, Any], tool_data: ToolData
) -> bool:
"""Check for missing arguments and handle them if found.""" """Check for missing arguments and handle them if found."""
missing_args = [key for key, value in args.items() if value is None] missing_args = [key for key, value in args.items() if value is None]
@@ -67,7 +79,9 @@ class ToolWorkflow:
f"and following missing arguments for tool {current_tool}: {missing_args}. " f"and following missing arguments for tool {current_tool}: {missing_args}. "
"Only provide a valid JSON response without any comments or metadata." "Only provide a valid JSON response without any comments or metadata."
) )
workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}") workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
return True return True
return False return False
@@ -76,15 +90,16 @@ class ToolWorkflow:
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE: if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
summary_context, summary_prompt = self.prompt_summary_with_history() summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput( summary_input = ToolPromptInput(
prompt=summary_prompt, prompt=summary_prompt, context_instructions=summary_context
context_instructions=summary_context
) )
self.conversation_summary = await workflow.start_activity_method( self.conversation_summary = await workflow.start_activity_method(
ToolActivities.prompt_llm, ToolActivities.prompt_llm,
summary_input, summary_input,
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT, schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
) )
workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.") workflow.logger.info(
f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns."
)
workflow.continue_as_new( workflow.continue_as_new(
args=[ args=[
CombinedInput( CombinedInput(
@@ -152,8 +167,7 @@ class ToolWorkflow:
prompt_input, prompt_input,
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT, schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
retry_policy=RetryPolicy( retry_policy=RetryPolicy(
maximum_attempts=5, maximum_attempts=5, initial_interval=timedelta(seconds=15)
initial_interval=timedelta(seconds=12)
), ),
) )
self.tool_data = tool_data self.tool_data = tool_data