From 7977894f64556c5343291b9e8f4a82e63e069bdd Mon Sep 17 00:00:00 2001 From: Steve Androulakis Date: Fri, 24 Jan 2025 15:23:57 -0800 Subject: [PATCH] fixes to issues 1 2 and 3. Plus tuning --- .env.example | 7 ++- README.md | 29 ++++++++--- activities/tool_activities.py | 68 ++++++++++++++++++++++++-- api/main.py | 29 +++++++++-- frontend/src/components/ChatWindow.jsx | 2 +- tools/find_events.py | 2 +- tools/tool_registry.py | 8 +-- workflows/tool_workflow.py | 48 +++++++++++------- 8 files changed, 154 insertions(+), 39 deletions(-) diff --git a/.env.example b/.env.example index 09558ea..3f29a28 100644 --- a/.env.example +++ b/.env.example @@ -5,6 +5,11 @@ RAPIDAPI_HOST=sky-scrapper.p.rapidapi.com STRIPE_API_KEY=sk_test_51J... +LLM_PROVIDER=openai # default +# or +# LLM_PROVIDER=ollama +# OLLAMA_MODEL_NAME=qwen2.5:14b + # uncomment and unset these environment variables to connect to the local dev server # TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233 # TEMPORAL_NAMESPACE=default @@ -15,4 +20,4 @@ STRIPE_API_KEY=sk_test_51J... # TEMPORAL_TLS_KEY='path/to/key.pem' # Uncomment if using API key (not needed for local dev server) -# TEMPORAL_API_KEY=abcdef1234567890 \ No newline at end of file +# TEMPORAL_API_KEY=abcdef1234567890 diff --git a/README.md b/README.md index 229dce8..72f36d2 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,28 @@ This application uses `.env` files for configuration. Copy the [.env.example](.e cp .env.example .env ``` -The agent requires an OpenAI key for the gpt-4o model. Set this in the `OPENAI_API_KEY` environment variable in .env +### LLM Provider Configuration + +The agent can use either OpenAI's GPT-4o or a local LLM via Ollama. 
Set the `LLM_PROVIDER` environment variable in your `.env` file to choose the desired provider: + +- `LLM_PROVIDER=openai` for OpenAI's GPT-4o +- `LLM_PROVIDER=ollama` for the local LLM via Ollama (not recommended for this use case) + +### OpenAI Configuration + +If using OpenAI, ensure you have an OpenAI key for the GPT-4o model. Set this in the `OPENAI_API_KEY` environment variable in `.env`. + +### Ollama Configuration + +To use a local LLM with Ollama: + +1. Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model. + - Run `ollama run ` to start the model. Note that this model is about 9GB to download. + - Example: `ollama run qwen2.5:14b` + +2. Set `LLM_PROVIDER=ollama` in your `.env` file and `OLLAMA_MODEL_NAME` to the name of the model you installed. + +Note: The local LLM is disabled by default as ChatGPT 4o was found to be MUCH more reliable for this use case. However, you can switch to Ollama if desired. ## Agent Tools * Requires a Rapidapi key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env @@ -85,12 +106,6 @@ Access the UI at `http://localhost:5173` - Note the mapping in `tools/__init__.py` to each tool - See main.py where some tool-specific logic is defined (todo, move this to the tool definition) -## Using a local LLM instead of ChatGPT 4o -With a small code change, the agent can use local LLMs. -* Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model (`ollama run qwen2.5:14b`). (note this model is about 9GB to download). - * Local LLM is disabled as ChatGPT 4o was better for this use case. To use Ollama, examine `./activities/tool_activities.py` and rename the existing functions. - * Note that Qwen2.5 14B is not as good as ChatGPT 4o for this use case and will perform worse at moving the conversation towards the goal. 
- ## TODO - I should prove this out with other tool definitions outside of the event/flight search case (take advantage of my nice DSL). - Currently hardcoded to the Temporal dev server at localhost:7233. Need to support options incl Temporal Cloud. diff --git a/activities/tool_activities.py b/activities/tool_activities.py index e7fbb67..215c289 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -18,6 +18,14 @@ class ToolPromptInput: class ToolActivities: @activity.defn def prompt_llm(self, input: ToolPromptInput) -> dict: + llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower() + + if llm_provider == "ollama": + return self.prompt_llm_ollama(input) + else: + return self.prompt_llm_openai(input) + + def prompt_llm_openai(self, input: ToolPromptInput) -> dict: client = OpenAI( api_key=os.environ.get( "OPENAI_API_KEY" @@ -44,9 +52,8 @@ class ToolActivities: response_content = chat_completion.choices[0].message.content print(f"ChatGPT response: {response_content}") - # Trim formatting markers if present - if response_content.startswith("```json") and response_content.endswith("```"): - response_content = response_content[7:-3].strip() + # Use the new sanitize function + response_content = self.sanitize_json_response(response_content) try: data = json.loads(response_content) @@ -58,7 +65,7 @@ class ToolActivities: @activity.defn def prompt_llm_ollama(self, input: ToolPromptInput) -> dict: - model_name = "qwen2.5:14b" + model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b") messages = [ { "role": "system", @@ -76,8 +83,11 @@ class ToolActivities: print(f"Chat response: {response.message.content}") + # Use the new sanitize function + response_content = self.sanitize_json_response(response.message.content) + try: - data = json.loads(response.message.content) + data = json.loads(response_content) except json.JSONDecodeError as e: print(f"Invalid JSON: {e}") print(response.message.content) @@ -85,6 +95,54 @@ class 
ToolActivities: return data + def sanitize_json_response(self, response_content: str) -> str: + """ + Extracts the JSON block from the response content as a string. + Supports: + - JSON surrounded by ```json and ``` + - Raw JSON input + - JSON preceded or followed by extra text + Rejects invalid input that doesn't contain JSON. + """ + try: + start_marker = "```json" + end_marker = "```" + + json_str = None + + # Case 1: JSON surrounded by markers + if start_marker in response_content and end_marker in response_content: + json_start = response_content.index(start_marker) + len(start_marker) + json_end = response_content.index(end_marker, json_start) + json_str = response_content[json_start:json_end].strip() + + # Case 2: Text with valid JSON + else: + # Try to locate the JSON block by scanning for the first `{` and last `}` + json_start = response_content.find("{") + json_end = response_content.rfind("}") + + if json_start != -1 and json_end != -1 and json_start < json_end: + json_str = response_content[json_start : json_end + 1].strip() + + # Validate and ensure the extracted JSON is valid + if json_str: + json.loads(json_str) # This will raise an error if the JSON is invalid + return json_str + + # If no valid JSON found, raise an error + raise ValueError("Response does not contain valid JSON.") + + except json.JSONDecodeError: + # Invalid JSON + print(f"Invalid JSON detected in response: {response_content}") + raise ValueError("Response does not contain valid JSON.") + except Exception as e: + # Other errors + print(f"Error processing response: {str(e)}") + print(f"Full response: {response_content}") + raise + def get_current_date_human_readable(): """ diff --git a/api/main.py b/api/main.py index 0ed92e4..252b714 100644 --- a/api/main.py +++ b/api/main.py @@ -1,14 +1,14 @@ from fastapi import FastAPI from typing import Optional from temporalio.client import Client +from temporalio.exceptions import TemporalError +from temporalio.api.enums.v1 import 
WorkflowExecutionStatus from workflows.tool_workflow import ToolWorkflow from models.data_types import CombinedInput, ToolWorkflowParams from tools.goal_registry import goal_event_flight_invoice -from temporalio.exceptions import TemporalError from fastapi.middleware.cors import CORSMiddleware from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE - app = FastAPI() temporal_client: Optional[Client] = None @@ -58,11 +58,32 @@ async def get_conversation_history(): """Calls the workflow's 'get_conversation_history' query.""" try: handle = temporal_client.get_workflow_handle("agent-workflow") - conversation_history = await handle.query("get_conversation_history") + status_names = { + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED", + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED", + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED" + } + + failed_states = [ + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED, + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED, + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED + ] + + # Check workflow status first + description = await handle.describe() + if description.status in failed_states: + status_name = status_names.get(description.status, "UNKNOWN_STATUS") + print(f"Workflow is in {status_name} state. 
Returning empty history.") + return [] + + # Only query if workflow is running + conversation_history = await handle.query("get_conversation_history") return conversation_history + except TemporalError as e: - print(e) + print(f"Temporal error: {e}") return [] diff --git a/frontend/src/components/ChatWindow.jsx b/frontend/src/components/ChatWindow.jsx index 82e1042..21a12d3 100644 --- a/frontend/src/components/ChatWindow.jsx +++ b/frontend/src/components/ChatWindow.jsx @@ -21,7 +21,7 @@ class ChatErrorBoundary extends React.Component { if (this.state.hasError) { return (
- Something went wrong. Please refresh the page. + Something went wrong. Please terminate the workflow and try again.
); } diff --git a/tools/find_events.py b/tools/find_events.py index 364afbc..51f3d42 100644 --- a/tools/find_events.py +++ b/tools/find_events.py @@ -53,7 +53,7 @@ def find_events(args: dict) -> dict: "dateFrom": event["dateFrom"], "dateTo": event["dateTo"], "description": event["description"], - "monthContext": month_context, + "month": month_context, } ) diff --git a/tools/tool_registry.py b/tools/tool_registry.py index 0c39821..a429c81 100644 --- a/tools/tool_registry.py +++ b/tools/tool_registry.py @@ -3,17 +3,19 @@ from models.tool_definitions import ToolDefinition, ToolArgument find_events_tool = ToolDefinition( name="FindEvents", description="Find upcoming events to travel to a given city (e.g., 'Melbourne') and a date or month. " - "It knows about events in Oceania only (e.g. major Australian and New Zealand cities).", + "It knows about events in Oceania only (e.g. major Australian and New Zealand cities). " + "It will search 1 month either side of the month provided. " + "Returns a list of events. 
", arguments=[ ToolArgument( name="city", type="string", - description="Which city to search for events", + description="Which city to search for events", ), ToolArgument( name="month", type="string", - description="The month or approximate date range to find events", + description="The month to search for events (will search 1 month either side of the month provided)", ), ], ) diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index b5463c4..c325aeb 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -20,12 +20,14 @@ Message = Dict[str, Union[str, Dict[str, Any]]] ConversationHistory = Dict[str, List[Message]] NextStep = Literal["confirm", "question", "done"] + class ToolData(TypedDict, total=False): next: NextStep tool: str args: Dict[str, Any] response: str + @workflow.defn class ToolWorkflow: """Workflow that manages tool execution with user confirmation and conversation history.""" @@ -39,35 +41,47 @@ class ToolWorkflow: self.confirm: bool = False self.tool_results: List[Dict[str, Any]] = [] - async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None: + async def _handle_tool_execution( + self, current_tool: str, tool_data: ToolData + ) -> None: """Execute a tool after confirmation and handle its result.""" workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") - + dynamic_result = await workflow.execute_activity( current_tool, tool_data["args"], schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT, ) dynamic_result["tool"] = current_tool - self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result}) + self.add_message( + "tool_result", {"tool": current_tool, "result": dynamic_result} + ) self.prompt_queue.append( f"### The '{current_tool}' tool completed successfully with {dynamic_result}. " - "INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. 
" - "DON'T ask any clarifying questions that are outside of the tools and args specified. " + "INSTRUCTIONS: Parse this tool result as plain text, and use the system prompt containing the list of tools in sequence and the conversation history to figure out next steps, if any. " + '{"next": "", "tool": "", "args": {"": "", "": "}, "response": ""}' + "ONLY return those json keys (next, tool, args, response), nothing else." + 'Next should only be "done" if all tools have been run (use the system prompt to figure that out).' + 'Next should be "question" if the tool is not the last one in the sequence.' + 'Next should NOT be "confirm" at this point.' ) - async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool: + async def _handle_missing_args( + self, current_tool: str, args: Dict[str, Any], tool_data: ToolData + ) -> bool: """Check for missing arguments and handle them if found.""" missing_args = [key for key, value in args.items() if value is None] - + if missing_args: self.prompt_queue.append( f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' " f"and following missing arguments for tool {current_tool}: {missing_args}. " "Only provide a valid JSON response without any comments or metadata." 
) - workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}") + workflow.logger.info( + f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}" + ) return True return False @@ -76,15 +90,16 @@ class ToolWorkflow: if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE: summary_context, summary_prompt = self.prompt_summary_with_history() summary_input = ToolPromptInput( - prompt=summary_prompt, - context_instructions=summary_context + prompt=summary_prompt, context_instructions=summary_context ) self.conversation_summary = await workflow.start_activity_method( ToolActivities.prompt_llm, summary_input, schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT, ) - workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.") + workflow.logger.info( + f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns." + ) workflow.continue_as_new( args=[ CombinedInput( @@ -146,14 +161,13 @@ class ToolWorkflow: prompt=prompt, context_instructions=context_instructions, ) - + tool_data = await workflow.execute_activity( ToolActivities.prompt_llm, prompt_input, schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT, retry_policy=RetryPolicy( - maximum_attempts=5, - initial_interval=timedelta(seconds=12) + maximum_attempts=5, initial_interval=timedelta(seconds=15) ), ) self.tool_data = tool_data @@ -219,7 +233,7 @@ class ToolWorkflow: def prompt_with_history(self, prompt: str) -> tuple[str, str]: """Generate a context-aware prompt with conversation history. - + Returns: tuple[str, str]: A tuple of (context_instructions, prompt) """ @@ -234,7 +248,7 @@ class ToolWorkflow: def prompt_summary_with_history(self) -> tuple[str, str]: """Generate a prompt for summarizing the conversation. 
- + Returns: tuple[str, str]: A tuple of (context_instructions, prompt) """ @@ -248,7 +262,7 @@ class ToolWorkflow: def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None: """Add a message to the conversation history. - + Args: actor: The entity that generated the message (e.g., "user", "agent") response: The message content, either as a string or structured data