mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
fixes to issues 1 2 and 3. Plus tuning
This commit is contained in:
@@ -5,6 +5,11 @@ RAPIDAPI_HOST=sky-scrapper.p.rapidapi.com
|
||||
|
||||
STRIPE_API_KEY=sk_test_51J...
|
||||
|
||||
LLM_PROVIDER=openai # default
|
||||
# or
|
||||
# LLM_PROVIDER=ollama
|
||||
# OLLAMA_MODEL_NAME=qwen2.5:14b
|
||||
|
||||
# Leave these commented out (unset) to connect to the local dev server; uncomment and set them to connect to another Temporal server
|
||||
# TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233
|
||||
# TEMPORAL_NAMESPACE=default
|
||||
|
||||
29
README.md
29
README.md
@@ -14,7 +14,28 @@ This application uses `.env` files for configuration. Copy the [.env.example](.e
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
The agent requires an OpenAI key for the gpt-4o model. Set this in the `OPENAI_API_KEY` environment variable in .env
|
||||
### LLM Provider Configuration
|
||||
|
||||
The agent can use either OpenAI's GPT-4o or a local LLM via Ollama. Set the `LLM_PROVIDER` environment variable in your `.env` file to choose the desired provider:
|
||||
|
||||
- `LLM_PROVIDER=openai` for OpenAI's GPT-4o
|
||||
- `LLM_PROVIDER=ollama` for the local LLM via Ollama (not recommended for this use case)
|
||||
|
||||
### OpenAI Configuration
|
||||
|
||||
If using OpenAI, ensure you have an OpenAI key for the GPT-4o model. Set this in the `OPENAI_API_KEY` environment variable in `.env`.
|
||||
|
||||
### Ollama Configuration
|
||||
|
||||
To use a local LLM with Ollama:
|
||||
|
||||
1. Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model.
|
||||
- Run `ollama run <OLLAMA_MODEL_NAME>` to start the model. Note that this model is about 9GB to download.
|
||||
- Example: `ollama run qwen2.5:14b`
|
||||
|
||||
2. Set `LLM_PROVIDER=ollama` in your `.env` file and `OLLAMA_MODEL_NAME` to the name of the model you installed.
|
||||
|
||||
Note: Ollama is not the default provider because GPT-4o was found to be much more reliable for this use case. However, you can switch to Ollama if desired.
|
||||
|
||||
## Agent Tools
|
||||
* Requires a RapidAPI key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in `.env`
|
||||
@@ -85,12 +106,6 @@ Access the UI at `http://localhost:5173`
|
||||
- Note the mapping in `tools/__init__.py` to each tool
|
||||
- See main.py where some tool-specific logic is defined (todo, move this to the tool definition)
|
||||
|
||||
## Using a local LLM instead of ChatGPT 4o
|
||||
With a small code change, the agent can use local LLMs.
|
||||
* Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model (`ollama run qwen2.5:14b`). (note this model is about 9GB to download).
|
||||
* Local LLM is disabled as ChatGPT 4o was better for this use case. To use Ollama, examine `./activities/tool_activities.py` and rename the existing functions.
|
||||
* Note that Qwen2.5 14B is not as good as ChatGPT 4o for this use case and will perform worse at moving the conversation towards the goal.
|
||||
|
||||
## TODO
|
||||
- I should prove this out with other tool definitions outside of the event/flight search case (take advantage of my nice DSL).
|
||||
- Currently hardcoded to the Temporal dev server at localhost:7233. Need to support other options, including Temporal Cloud.
|
||||
|
||||
@@ -18,6 +18,14 @@ class ToolPromptInput:
|
||||
class ToolActivities:
|
||||
@activity.defn
|
||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
    """Send the prompt to the LLM backend chosen by the ``LLM_PROVIDER`` env var.

    The variable is read on every call and compared case-insensitively.
    ``"ollama"`` selects ``prompt_llm_ollama``; any other value, or an
    unset variable (default ``"openai"``), selects ``prompt_llm_openai``.

    Args:
        input: Prompt payload, forwarded unchanged to the selected backend.

    Returns:
        Whatever the selected backend method returns.
    """
    wants_ollama = os.environ.get("LLM_PROVIDER", "openai").lower() == "ollama"
    # Only the selected backend attribute is looked up, as in a plain if/else.
    backend = self.prompt_llm_ollama if wants_ollama else self.prompt_llm_openai
    return backend(input)
|
||||
|
||||
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
|
||||
client = OpenAI(
|
||||
api_key=os.environ.get(
|
||||
"OPENAI_API_KEY"
|
||||
@@ -44,9 +52,8 @@ class ToolActivities:
|
||||
response_content = chat_completion.choices[0].message.content
|
||||
print(f"ChatGPT response: {response_content}")
|
||||
|
||||
# Trim formatting markers if present
|
||||
if response_content.startswith("```json") and response_content.endswith("```"):
|
||||
response_content = response_content[7:-3].strip()
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
@@ -58,7 +65,7 @@ class ToolActivities:
|
||||
|
||||
@activity.defn
|
||||
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
|
||||
model_name = "qwen2.5:14b"
|
||||
model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -76,8 +83,11 @@ class ToolActivities:
|
||||
|
||||
print(f"Chat response: {response.message.content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response.message.content)
|
||||
|
||||
try:
|
||||
data = json.loads(response.message.content)
|
||||
data = json.loads(response_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
print(response.message.content)
|
||||
@@ -85,6 +95,54 @@ class ToolActivities:
|
||||
|
||||
return data
|
||||
|
||||
def sanitize_json_response(self, response_content: str) -> str:
    """
    Extract the JSON object embedded in an LLM response and return it as a string.

    Supported inputs:
    - JSON surrounded by ```json and ``` (text around the fence is ignored)
    - A ```json fence that was never closed (e.g. a truncated reply); the
      object is then located by its outermost braces
    - Raw JSON objects
    - A JSON object preceded or followed by extra text

    Args:
        response_content: Raw text returned by the model.

    Returns:
        The extracted JSON text with surrounding whitespace stripped. It has
        already been parsed once with ``json.loads`` to confirm it is valid.

    Raises:
        ValueError: If the input is not a string or no valid JSON can be
            extracted from it.
    """
    start_marker = "```json"
    end_marker = "```"

    try:
        # Message content can be missing; reject it with the documented
        # exception type instead of leaking a TypeError from the `in` test.
        if not isinstance(response_content, str):
            raise ValueError("Response does not contain valid JSON.")

        json_str = None

        # Case 1: JSON surrounded by markers.
        marker_pos = response_content.find(start_marker)
        if marker_pos != -1:
            json_start = marker_pos + len(start_marker)
            # The opening fence itself contains "```", so the closing fence
            # must be searched for strictly after it; it may be absent.
            json_end = response_content.find(end_marker, json_start)
            if json_end != -1:
                json_str = response_content[json_start:json_end].strip()

        # Case 2: no complete fence -- scan for the first `{` and last `}`.
        if json_str is None:
            json_start = response_content.find("{")
            json_end = response_content.rfind("}")

            if json_start != -1 and json_end != -1 and json_start < json_end:
                json_str = response_content[json_start : json_end + 1].strip()

        # Validate and ensure the extracted JSON is valid.
        if json_str:
            json.loads(json_str)  # Raises JSONDecodeError if the JSON is invalid
            return json_str

        # If no valid JSON found, raise an error.
        raise ValueError("Response does not contain valid JSON.")

    except json.JSONDecodeError as err:
        # Invalid JSON: keep the decoder error as the cause for debugging.
        print(f"Invalid JSON detected in response: {response_content}")
        raise ValueError("Response does not contain valid JSON.") from err
    except Exception as e:
        # Other errors (including the ValueError raised above): log and re-raise.
        print(f"Error processing response: {str(e)}")
        print(f"Full response: {response_content}")
        raise
|
||||
|
||||
|
||||
def get_current_date_human_readable():
|
||||
"""
|
||||
|
||||
29
api/main.py
29
api/main.py
@@ -1,14 +1,14 @@
|
||||
from fastapi import FastAPI
|
||||
from typing import Optional
|
||||
from temporalio.client import Client
|
||||
from temporalio.exceptions import TemporalError
|
||||
from temporalio.api.enums.v1 import WorkflowExecutionStatus
|
||||
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
from tools.goal_registry import goal_event_flight_invoice
|
||||
from temporalio.exceptions import TemporalError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
||||
|
||||
app = FastAPI()
|
||||
temporal_client: Optional[Client] = None
|
||||
|
||||
@@ -58,11 +58,32 @@ async def get_conversation_history():
|
||||
"""Calls the workflow's 'get_conversation_history' query."""
|
||||
try:
|
||||
handle = temporal_client.get_workflow_handle("agent-workflow")
|
||||
conversation_history = await handle.query("get_conversation_history")
|
||||
|
||||
status_names = {
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED"
|
||||
}
|
||||
|
||||
failed_states = [
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED
|
||||
]
|
||||
|
||||
# Check workflow status first
|
||||
description = await handle.describe()
|
||||
if description.status in failed_states:
|
||||
status_name = status_names.get(description.status, "UNKNOWN_STATUS")
|
||||
print(f"Workflow is in {status_name} state. Returning empty history.")
|
||||
return []
|
||||
|
||||
# Only query if workflow is running
|
||||
conversation_history = await handle.query("get_conversation_history")
|
||||
return conversation_history
|
||||
|
||||
except TemporalError as e:
|
||||
print(e)
|
||||
print(f"Temporal error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class ChatErrorBoundary extends React.Component {
|
||||
if (this.state.hasError) {
|
||||
return (
|
||||
<div className="text-red-500 p-4 text-center">
|
||||
Something went wrong. Please refresh the page.
|
||||
Something went wrong. Please Terminate the workflow and try again.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ def find_events(args: dict) -> dict:
|
||||
"dateFrom": event["dateFrom"],
|
||||
"dateTo": event["dateTo"],
|
||||
"description": event["description"],
|
||||
"monthContext": month_context,
|
||||
"month": month_context,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@ from models.tool_definitions import ToolDefinition, ToolArgument
|
||||
find_events_tool = ToolDefinition(
|
||||
name="FindEvents",
|
||||
description="Find upcoming events to travel to a given city (e.g., 'Melbourne') and a date or month. "
|
||||
"It knows about events in Oceania only (e.g. major Australian and New Zealand cities).",
|
||||
"It knows about events in Oceania only (e.g. major Australian and New Zealand cities). "
|
||||
"It will search 1 month either side of the month provided. "
|
||||
"Returns a list of events. ",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="city",
|
||||
@@ -13,7 +15,7 @@ find_events_tool = ToolDefinition(
|
||||
ToolArgument(
|
||||
name="month",
|
||||
type="string",
|
||||
description="The month or approximate date range to find events",
|
||||
description="The month to search for events (will search 1 month either side of the month provided)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -20,12 +20,14 @@ Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||
ConversationHistory = Dict[str, List[Message]]
|
||||
NextStep = Literal["confirm", "question", "done"]
|
||||
|
||||
|
||||
class ToolData(TypedDict, total=False):
    """Typed dictionary with the keys ``next``, ``tool``, ``args`` and ``response``.

    ``total=False`` makes every key optional, so a value of this type may
    carry any subset of them.
    NOTE(review): this appears to mirror the JSON object the workflow prompts
    the LLM to return -- confirm against the prompt text.
    """

    # Next step to take: "confirm", "question" or "done" (see NextStep).
    next: NextStep
    # Tool name.
    tool: str
    # Tool arguments keyed by argument name; a value of None is treated as a
    # missing argument by the workflow's missing-args check.
    args: Dict[str, Any]
    # Response text.
    response: str
|
||||
|
||||
|
||||
@workflow.defn
|
||||
class ToolWorkflow:
|
||||
"""Workflow that manages tool execution with user confirmation and conversation history."""
|
||||
@@ -39,7 +41,9 @@ class ToolWorkflow:
|
||||
self.confirm: bool = False
|
||||
self.tool_results: List[Dict[str, Any]] = []
|
||||
|
||||
async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None:
|
||||
async def _handle_tool_execution(
|
||||
self, current_tool: str, tool_data: ToolData
|
||||
) -> None:
|
||||
"""Execute a tool after confirmation and handle its result."""
|
||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||
|
||||
@@ -49,15 +53,23 @@ class ToolWorkflow:
|
||||
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
||||
)
|
||||
dynamic_result["tool"] = current_tool
|
||||
self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result})
|
||||
self.add_message(
|
||||
"tool_result", {"tool": current_tool, "result": dynamic_result}
|
||||
)
|
||||
|
||||
self.prompt_queue.append(
|
||||
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
|
||||
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
|
||||
"INSTRUCTIONS: Parse this tool result as plain text, and use the system prompt containing the list of tools in sequence and the conversation history to figure out next steps, if any. "
|
||||
'{"next": "<question|confirm|done>", "tool": "<tool_name or null>", "args": {"<arg1>": "<value1 or null>", "<arg2>": "<value2 or null>}, "response": "<plain text>"}'
|
||||
"ONLY return those json keys (next, tool, args, response), nothing else."
|
||||
'Next should only be "done" if all tools have been run (use the system prompt to figure that out).'
|
||||
'Next should be "question" if the tool is not the last one in the sequence.'
|
||||
'Next should NOT be "confirm" at this point.'
|
||||
)
|
||||
|
||||
async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool:
|
||||
async def _handle_missing_args(
|
||||
self, current_tool: str, args: Dict[str, Any], tool_data: ToolData
|
||||
) -> bool:
|
||||
"""Check for missing arguments and handle them if found."""
|
||||
missing_args = [key for key, value in args.items() if value is None]
|
||||
|
||||
@@ -67,7 +79,9 @@ class ToolWorkflow:
|
||||
f"and following missing arguments for tool {current_tool}: {missing_args}. "
|
||||
"Only provide a valid JSON response without any comments or metadata."
|
||||
)
|
||||
workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}")
|
||||
workflow.logger.info(
|
||||
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -76,15 +90,16 @@ class ToolWorkflow:
|
||||
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
|
||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||
summary_input = ToolPromptInput(
|
||||
prompt=summary_prompt,
|
||||
context_instructions=summary_context
|
||||
prompt=summary_prompt, context_instructions=summary_context
|
||||
)
|
||||
self.conversation_summary = await workflow.start_activity_method(
|
||||
ToolActivities.prompt_llm,
|
||||
summary_input,
|
||||
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
||||
)
|
||||
workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.")
|
||||
workflow.logger.info(
|
||||
f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns."
|
||||
)
|
||||
workflow.continue_as_new(
|
||||
args=[
|
||||
CombinedInput(
|
||||
@@ -152,8 +167,7 @@ class ToolWorkflow:
|
||||
prompt_input,
|
||||
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||
retry_policy=RetryPolicy(
|
||||
maximum_attempts=5,
|
||||
initial_interval=timedelta(seconds=12)
|
||||
maximum_attempts=5, initial_interval=timedelta(seconds=15)
|
||||
),
|
||||
)
|
||||
self.tool_data = tool_data
|
||||
|
||||
Reference in New Issue
Block a user