mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
fixes to issues 1 2 and 3. Plus tuning
This commit is contained in:
@@ -5,6 +5,11 @@ RAPIDAPI_HOST=sky-scrapper.p.rapidapi.com
|
|||||||
|
|
||||||
STRIPE_API_KEY=sk_test_51J...
|
STRIPE_API_KEY=sk_test_51J...
|
||||||
|
|
||||||
|
LLM_PROVIDER=openai # default
|
||||||
|
# or
|
||||||
|
# LLM_PROVIDER=ollama
|
||||||
|
# OLLAMA_MODEL_NAME=qwen2.5:14b
|
||||||
|
|
||||||
# uncomment and unset these environment variables to connect to the local dev server
|
# uncomment and unset these environment variables to connect to the local dev server
|
||||||
# TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233
|
# TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233
|
||||||
# TEMPORAL_NAMESPACE=default
|
# TEMPORAL_NAMESPACE=default
|
||||||
|
|||||||
29
README.md
29
README.md
@@ -14,7 +14,28 @@ This application uses `.env` files for configuration. Copy the [.env.example](.e
|
|||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
```
|
```
|
||||||
|
|
||||||
The agent requires an OpenAI key for the gpt-4o model. Set this in the `OPENAI_API_KEY` environment variable in .env
|
### LLM Provider Configuration
|
||||||
|
|
||||||
|
The agent can use either OpenAI's GPT-4o or a local LLM via Ollama. Set the `LLM_PROVIDER` environment variable in your `.env` file to choose the desired provider:
|
||||||
|
|
||||||
|
- `LLM_PROVIDER=openai` for OpenAI's GPT-4o
|
||||||
|
- `LLM_PROVIDER=ollama` for the local LLM via Ollama (not recommended for this use case)
|
||||||
|
|
||||||
|
### OpenAI Configuration
|
||||||
|
|
||||||
|
If using OpenAI, ensure you have an OpenAI key for the GPT-4o model. Set this in the `OPENAI_API_KEY` environment variable in `.env`.
|
||||||
|
|
||||||
|
### Ollama Configuration
|
||||||
|
|
||||||
|
To use a local LLM with Ollama:
|
||||||
|
|
||||||
|
1. Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model.
|
||||||
|
- Run `ollama run <OLLAMA_MODEL_NAME>` to start the model. Note that this model is about 9GB to download.
|
||||||
|
- Example: `ollama run qwen2.5:14b`
|
||||||
|
|
||||||
|
2. Set `LLM_PROVIDER=ollama` in your `.env` file and `OLLAMA_MODEL_NAME` to the name of the model you installed.
|
||||||
|
|
||||||
|
Note: The local LLM is disabled by default as ChatGPT 4o was found to be MUCH more reliable for this use case. However, you can switch to Ollama if desired.
|
||||||
|
|
||||||
## Agent Tools
|
## Agent Tools
|
||||||
* Requires a Rapidapi key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env
|
* Requires a Rapidapi key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env
|
||||||
@@ -85,12 +106,6 @@ Access the UI at `http://localhost:5173`
|
|||||||
- Note the mapping in `tools/__init__.py` to each tool
|
- Note the mapping in `tools/__init__.py` to each tool
|
||||||
- See main.py where some tool-specific logic is defined (todo, move this to the tool definition)
|
- See main.py where some tool-specific logic is defined (todo, move this to the tool definition)
|
||||||
|
|
||||||
## Using a local LLM instead of ChatGPT 4o
|
|
||||||
With a small code change, the agent can use local LLMs.
|
|
||||||
* Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model (`ollama run qwen2.5:14b`). (note this model is about 9GB to download).
|
|
||||||
* Local LLM is disabled as ChatGPT 4o was better for this use case. To use Ollama, examine `./activities/tool_activities.py` and rename the existing functions.
|
|
||||||
* Note that Qwen2.5 14B is not as good as ChatGPT 4o for this use case and will perform worse at moving the conversation towards the goal.
|
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
- I should prove this out with other tool definitions outside of the event/flight search case (take advantage of my nice DSL).
|
- I should prove this out with other tool definitions outside of the event/flight search case (take advantage of my nice DSL).
|
||||||
- Currently hardcoded to the Temporal dev server at localhost:7233. Need to support options incl Temporal Cloud.
|
- Currently hardcoded to the Temporal dev server at localhost:7233. Need to support options incl Temporal Cloud.
|
||||||
|
|||||||
@@ -18,6 +18,14 @@ class ToolPromptInput:
|
|||||||
class ToolActivities:
|
class ToolActivities:
|
||||||
@activity.defn
|
@activity.defn
|
||||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||||
|
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||||
|
|
||||||
|
if llm_provider == "ollama":
|
||||||
|
return self.prompt_llm_ollama(input)
|
||||||
|
else:
|
||||||
|
return self.prompt_llm_openai(input)
|
||||||
|
|
||||||
|
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
api_key=os.environ.get(
|
api_key=os.environ.get(
|
||||||
"OPENAI_API_KEY"
|
"OPENAI_API_KEY"
|
||||||
@@ -44,9 +52,8 @@ class ToolActivities:
|
|||||||
response_content = chat_completion.choices[0].message.content
|
response_content = chat_completion.choices[0].message.content
|
||||||
print(f"ChatGPT response: {response_content}")
|
print(f"ChatGPT response: {response_content}")
|
||||||
|
|
||||||
# Trim formatting markers if present
|
# Use the new sanitize function
|
||||||
if response_content.startswith("```json") and response_content.endswith("```"):
|
response_content = self.sanitize_json_response(response_content)
|
||||||
response_content = response_content[7:-3].strip()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(response_content)
|
data = json.loads(response_content)
|
||||||
@@ -58,7 +65,7 @@ class ToolActivities:
|
|||||||
|
|
||||||
@activity.defn
|
@activity.defn
|
||||||
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
|
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
|
||||||
model_name = "qwen2.5:14b"
|
model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -76,8 +83,11 @@ class ToolActivities:
|
|||||||
|
|
||||||
print(f"Chat response: {response.message.content}")
|
print(f"Chat response: {response.message.content}")
|
||||||
|
|
||||||
|
# Use the new sanitize function
|
||||||
|
response_content = self.sanitize_json_response(response.message.content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(response.message.content)
|
data = json.loads(response_content)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"Invalid JSON: {e}")
|
print(f"Invalid JSON: {e}")
|
||||||
print(response.message.content)
|
print(response.message.content)
|
||||||
@@ -85,6 +95,54 @@ class ToolActivities:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def sanitize_json_response(self, response_content: str) -> str:
|
||||||
|
"""
|
||||||
|
Extracts the JSON block from the response content as a string.
|
||||||
|
Supports:
|
||||||
|
- JSON surrounded by ```json and ```
|
||||||
|
- Raw JSON input
|
||||||
|
- JSON preceded or followed by extra text
|
||||||
|
Rejects invalid input that doesn't contain JSON.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
start_marker = "```json"
|
||||||
|
end_marker = "```"
|
||||||
|
|
||||||
|
json_str = None
|
||||||
|
|
||||||
|
# Case 1: JSON surrounded by markers
|
||||||
|
if start_marker in response_content and end_marker in response_content:
|
||||||
|
json_start = response_content.index(start_marker) + len(start_marker)
|
||||||
|
json_end = response_content.index(end_marker, json_start)
|
||||||
|
json_str = response_content[json_start:json_end].strip()
|
||||||
|
|
||||||
|
# Case 2: Text with valid JSON
|
||||||
|
else:
|
||||||
|
# Try to locate the JSON block by scanning for the first `{` and last `}`
|
||||||
|
json_start = response_content.find("{")
|
||||||
|
json_end = response_content.rfind("}")
|
||||||
|
|
||||||
|
if json_start != -1 and json_end != -1 and json_start < json_end:
|
||||||
|
json_str = response_content[json_start : json_end + 1].strip()
|
||||||
|
|
||||||
|
# Validate and ensure the extracted JSON is valid
|
||||||
|
if json_str:
|
||||||
|
json.loads(json_str) # This will raise an error if the JSON is invalid
|
||||||
|
return json_str
|
||||||
|
|
||||||
|
# If no valid JSON found, raise an error
|
||||||
|
raise ValueError("Response does not contain valid JSON.")
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Invalid JSON
|
||||||
|
print(f"Invalid JSON detected in response: {response_content}")
|
||||||
|
raise ValueError("Response does not contain valid JSON.")
|
||||||
|
except Exception as e:
|
||||||
|
# Other errors
|
||||||
|
print(f"Error processing response: {str(e)}")
|
||||||
|
print(f"Full response: {response_content}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_current_date_human_readable():
|
def get_current_date_human_readable():
|
||||||
"""
|
"""
|
||||||
|
|||||||
29
api/main.py
29
api/main.py
@@ -1,14 +1,14 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from temporalio.client import Client
|
from temporalio.client import Client
|
||||||
|
from temporalio.exceptions import TemporalError
|
||||||
|
from temporalio.api.enums.v1 import WorkflowExecutionStatus
|
||||||
|
|
||||||
from workflows.tool_workflow import ToolWorkflow
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||||
from tools.goal_registry import goal_event_flight_invoice
|
from tools.goal_registry import goal_event_flight_invoice
|
||||||
from temporalio.exceptions import TemporalError
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
temporal_client: Optional[Client] = None
|
temporal_client: Optional[Client] = None
|
||||||
|
|
||||||
@@ -58,11 +58,32 @@ async def get_conversation_history():
|
|||||||
"""Calls the workflow's 'get_conversation_history' query."""
|
"""Calls the workflow's 'get_conversation_history' query."""
|
||||||
try:
|
try:
|
||||||
handle = temporal_client.get_workflow_handle("agent-workflow")
|
handle = temporal_client.get_workflow_handle("agent-workflow")
|
||||||
conversation_history = await handle.query("get_conversation_history")
|
|
||||||
|
|
||||||
|
status_names = {
|
||||||
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
|
||||||
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
|
||||||
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED"
|
||||||
|
}
|
||||||
|
|
||||||
|
failed_states = [
|
||||||
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
|
||||||
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
|
||||||
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check workflow status first
|
||||||
|
description = await handle.describe()
|
||||||
|
if description.status in failed_states:
|
||||||
|
status_name = status_names.get(description.status, "UNKNOWN_STATUS")
|
||||||
|
print(f"Workflow is in {status_name} state. Returning empty history.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Only query if workflow is running
|
||||||
|
conversation_history = await handle.query("get_conversation_history")
|
||||||
return conversation_history
|
return conversation_history
|
||||||
|
|
||||||
except TemporalError as e:
|
except TemporalError as e:
|
||||||
print(e)
|
print(f"Temporal error: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class ChatErrorBoundary extends React.Component {
|
|||||||
if (this.state.hasError) {
|
if (this.state.hasError) {
|
||||||
return (
|
return (
|
||||||
<div className="text-red-500 p-4 text-center">
|
<div className="text-red-500 p-4 text-center">
|
||||||
Something went wrong. Please refresh the page.
|
Something went wrong. Please Terminate the workflow and try again.
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ def find_events(args: dict) -> dict:
|
|||||||
"dateFrom": event["dateFrom"],
|
"dateFrom": event["dateFrom"],
|
||||||
"dateTo": event["dateTo"],
|
"dateTo": event["dateTo"],
|
||||||
"description": event["description"],
|
"description": event["description"],
|
||||||
"monthContext": month_context,
|
"month": month_context,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,17 +3,19 @@ from models.tool_definitions import ToolDefinition, ToolArgument
|
|||||||
find_events_tool = ToolDefinition(
|
find_events_tool = ToolDefinition(
|
||||||
name="FindEvents",
|
name="FindEvents",
|
||||||
description="Find upcoming events to travel to a given city (e.g., 'Melbourne') and a date or month. "
|
description="Find upcoming events to travel to a given city (e.g., 'Melbourne') and a date or month. "
|
||||||
"It knows about events in Oceania only (e.g. major Australian and New Zealand cities).",
|
"It knows about events in Oceania only (e.g. major Australian and New Zealand cities). "
|
||||||
|
"It will search 1 month either side of the month provided. "
|
||||||
|
"Returns a list of events. ",
|
||||||
arguments=[
|
arguments=[
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="city",
|
name="city",
|
||||||
type="string",
|
type="string",
|
||||||
description="Which city to search for events",
|
description="Which city to search for events",
|
||||||
),
|
),
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="month",
|
name="month",
|
||||||
type="string",
|
type="string",
|
||||||
description="The month or approximate date range to find events",
|
description="The month to search for events (will search 1 month either side of the month provided)",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ Message = Dict[str, Union[str, Dict[str, Any]]]
|
|||||||
ConversationHistory = Dict[str, List[Message]]
|
ConversationHistory = Dict[str, List[Message]]
|
||||||
NextStep = Literal["confirm", "question", "done"]
|
NextStep = Literal["confirm", "question", "done"]
|
||||||
|
|
||||||
|
|
||||||
class ToolData(TypedDict, total=False):
|
class ToolData(TypedDict, total=False):
|
||||||
next: NextStep
|
next: NextStep
|
||||||
tool: str
|
tool: str
|
||||||
args: Dict[str, Any]
|
args: Dict[str, Any]
|
||||||
response: str
|
response: str
|
||||||
|
|
||||||
|
|
||||||
@workflow.defn
|
@workflow.defn
|
||||||
class ToolWorkflow:
|
class ToolWorkflow:
|
||||||
"""Workflow that manages tool execution with user confirmation and conversation history."""
|
"""Workflow that manages tool execution with user confirmation and conversation history."""
|
||||||
@@ -39,7 +41,9 @@ class ToolWorkflow:
|
|||||||
self.confirm: bool = False
|
self.confirm: bool = False
|
||||||
self.tool_results: List[Dict[str, Any]] = []
|
self.tool_results: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None:
|
async def _handle_tool_execution(
|
||||||
|
self, current_tool: str, tool_data: ToolData
|
||||||
|
) -> None:
|
||||||
"""Execute a tool after confirmation and handle its result."""
|
"""Execute a tool after confirmation and handle its result."""
|
||||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||||
|
|
||||||
@@ -49,15 +53,23 @@ class ToolWorkflow:
|
|||||||
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
||||||
)
|
)
|
||||||
dynamic_result["tool"] = current_tool
|
dynamic_result["tool"] = current_tool
|
||||||
self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result})
|
self.add_message(
|
||||||
|
"tool_result", {"tool": current_tool, "result": dynamic_result}
|
||||||
|
)
|
||||||
|
|
||||||
self.prompt_queue.append(
|
self.prompt_queue.append(
|
||||||
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||||
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
|
"INSTRUCTIONS: Parse this tool result as plain text, and use the system prompt containing the list of tools in sequence and the conversation history to figure out next steps, if any. "
|
||||||
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
|
'{"next": "<question|confirm|done>", "tool": "<tool_name or null>", "args": {"<arg1>": "<value1 or null>", "<arg2>": "<value2 or null>}, "response": "<plain text>"}'
|
||||||
|
"ONLY return those json keys (next, tool, args, response), nothing else."
|
||||||
|
'Next should only be "done" if all tools have been run (use the system prompt to figure that out).'
|
||||||
|
'Next should be "question" if the tool is not the last one in the sequence.'
|
||||||
|
'Next should NOT be "confirm" at this point.'
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool:
|
async def _handle_missing_args(
|
||||||
|
self, current_tool: str, args: Dict[str, Any], tool_data: ToolData
|
||||||
|
) -> bool:
|
||||||
"""Check for missing arguments and handle them if found."""
|
"""Check for missing arguments and handle them if found."""
|
||||||
missing_args = [key for key, value in args.items() if value is None]
|
missing_args = [key for key, value in args.items() if value is None]
|
||||||
|
|
||||||
@@ -67,7 +79,9 @@ class ToolWorkflow:
|
|||||||
f"and following missing arguments for tool {current_tool}: {missing_args}. "
|
f"and following missing arguments for tool {current_tool}: {missing_args}. "
|
||||||
"Only provide a valid JSON response without any comments or metadata."
|
"Only provide a valid JSON response without any comments or metadata."
|
||||||
)
|
)
|
||||||
workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}")
|
workflow.logger.info(
|
||||||
|
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -76,15 +90,16 @@ class ToolWorkflow:
|
|||||||
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
|
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
|
||||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||||
summary_input = ToolPromptInput(
|
summary_input = ToolPromptInput(
|
||||||
prompt=summary_prompt,
|
prompt=summary_prompt, context_instructions=summary_context
|
||||||
context_instructions=summary_context
|
|
||||||
)
|
)
|
||||||
self.conversation_summary = await workflow.start_activity_method(
|
self.conversation_summary = await workflow.start_activity_method(
|
||||||
ToolActivities.prompt_llm,
|
ToolActivities.prompt_llm,
|
||||||
summary_input,
|
summary_input,
|
||||||
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
||||||
)
|
)
|
||||||
workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.")
|
workflow.logger.info(
|
||||||
|
f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns."
|
||||||
|
)
|
||||||
workflow.continue_as_new(
|
workflow.continue_as_new(
|
||||||
args=[
|
args=[
|
||||||
CombinedInput(
|
CombinedInput(
|
||||||
@@ -152,8 +167,7 @@ class ToolWorkflow:
|
|||||||
prompt_input,
|
prompt_input,
|
||||||
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||||
retry_policy=RetryPolicy(
|
retry_policy=RetryPolicy(
|
||||||
maximum_attempts=5,
|
maximum_attempts=5, initial_interval=timedelta(seconds=15)
|
||||||
initial_interval=timedelta(seconds=12)
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.tool_data = tool_data
|
self.tool_data = tool_data
|
||||||
|
|||||||
Reference in New Issue
Block a user