prompt engineering, train_api date parsing changes

This commit is contained in:
Steve Androulakis
2025-02-11 09:35:40 -08:00
parent 7f6ff2397f
commit aeffe75a0a
5 changed files with 141 additions and 61 deletions

View File

@@ -6,17 +6,20 @@ from temporalio.api.enums.v1 import WorkflowExecutionStatus
from workflows.tool_workflow import ToolWorkflow
from models.data_types import CombinedInput, ToolWorkflowParams
from tools.goal_registry import goal_event_flight_invoice
from tools.goal_registry import goal_match_train_invoice
from fastapi.middleware.cors import CORSMiddleware
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
app = FastAPI()
temporal_client: Optional[Client] = None
@app.on_event("startup")
async def startup_event():
global temporal_client
temporal_client = await get_temporal_client()
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173"],
@@ -62,13 +65,13 @@ async def get_conversation_history():
status_names = {
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED"
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED",
}
failed_states = [
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED,
]
# Check workflow status first
@@ -77,11 +80,11 @@ async def get_conversation_history():
status_name = status_names.get(description.status, "UNKNOWN_STATUS")
print(f"Workflow is in {status_name} state. Returning empty history.")
return []
# Only query if workflow is running
conversation_history = await handle.query("get_conversation_history")
return conversation_history
except TemporalError as e:
print(f"Temporal error: {e}")
return []
@@ -92,7 +95,7 @@ async def send_prompt(prompt: str):
# Create combined input
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None),
agent_goal=goal_event_flight_invoice,
agent_goal=goal_match_train_invoice,
)
workflow_id = "agent-workflow"
@@ -139,7 +142,7 @@ async def start_workflow():
# Create combined input
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None),
agent_goal=goal_event_flight_invoice,
agent_goal=goal_match_train_invoice,
)
workflow_id = "agent-workflow"
@@ -151,7 +154,9 @@ async def start_workflow():
id=workflow_id,
task_queue=TEMPORAL_TASK_QUEUE,
start_signal="user_prompt",
start_signal_args=["### " + goal_event_flight_invoice.starter_prompt],
start_signal_args=["### " + goal_match_train_invoice.starter_prompt],
)
return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."}
return {
"message": f"Workflow started with goal's starter prompt: {goal_match_train_invoice.starter_prompt}."
}