mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
prompt engineering, train_api date parsing changes
This commit is contained in:
23
api/main.py
23
api/main.py
@@ -6,17 +6,20 @@ from temporalio.api.enums.v1 import WorkflowExecutionStatus
|
||||
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
from tools.goal_registry import goal_event_flight_invoice
|
||||
from tools.goal_registry import goal_match_train_invoice
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
||||
|
||||
app = FastAPI()
|
||||
temporal_client: Optional[Client] = None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global temporal_client
|
||||
temporal_client = await get_temporal_client()
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173"],
|
||||
@@ -62,13 +65,13 @@ async def get_conversation_history():
|
||||
status_names = {
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED"
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED",
|
||||
}
|
||||
|
||||
failed_states = [
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED,
|
||||
]
|
||||
|
||||
# Check workflow status first
|
||||
@@ -77,11 +80,11 @@ async def get_conversation_history():
|
||||
status_name = status_names.get(description.status, "UNKNOWN_STATUS")
|
||||
print(f"Workflow is in {status_name} state. Returning empty history.")
|
||||
return []
|
||||
|
||||
|
||||
# Only query if workflow is running
|
||||
conversation_history = await handle.query("get_conversation_history")
|
||||
return conversation_history
|
||||
|
||||
|
||||
except TemporalError as e:
|
||||
print(f"Temporal error: {e}")
|
||||
return []
|
||||
@@ -92,7 +95,7 @@ async def send_prompt(prompt: str):
|
||||
# Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
agent_goal=goal_event_flight_invoice,
|
||||
agent_goal=goal_match_train_invoice,
|
||||
)
|
||||
|
||||
workflow_id = "agent-workflow"
|
||||
@@ -139,7 +142,7 @@ async def start_workflow():
|
||||
# Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
agent_goal=goal_event_flight_invoice,
|
||||
agent_goal=goal_match_train_invoice,
|
||||
)
|
||||
|
||||
workflow_id = "agent-workflow"
|
||||
@@ -151,7 +154,9 @@ async def start_workflow():
|
||||
id=workflow_id,
|
||||
task_queue=TEMPORAL_TASK_QUEUE,
|
||||
start_signal="user_prompt",
|
||||
start_signal_args=["### " + goal_event_flight_invoice.starter_prompt],
|
||||
start_signal_args=["### " + goal_match_train_invoice.starter_prompt],
|
||||
)
|
||||
|
||||
return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."}
|
||||
return {
|
||||
"message": f"Workflow started with goal's starter prompt: {goal_match_train_invoice.starter_prompt}."
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user