fixes to issues 1 2 and 3. Plus tuning

This commit is contained in:
Steve Androulakis
2025-01-24 15:23:57 -08:00
parent caf5812f90
commit 7977894f64
8 changed files with 154 additions and 39 deletions

View File

@@ -20,12 +20,14 @@ Message = Dict[str, Union[str, Dict[str, Any]]]
ConversationHistory = Dict[str, List[Message]]
NextStep = Literal["confirm", "question", "done"]
class ToolData(TypedDict, total=False):
next: NextStep
tool: str
args: Dict[str, Any]
response: str
@workflow.defn
class ToolWorkflow:
"""Workflow that manages tool execution with user confirmation and conversation history."""
@@ -39,35 +41,47 @@ class ToolWorkflow:
self.confirm: bool = False
self.tool_results: List[Dict[str, Any]] = []
async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None:
async def _handle_tool_execution(
self, current_tool: str, tool_data: ToolData
) -> None:
"""Execute a tool after confirmation and handle its result."""
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
dynamic_result = await workflow.execute_activity(
current_tool,
tool_data["args"],
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
)
dynamic_result["tool"] = current_tool
self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result})
self.add_message(
"tool_result", {"tool": current_tool, "result": dynamic_result}
)
self.prompt_queue.append(
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
"INSTRUCTIONS: Parse this tool result as plain text, and use the system prompt containing the list of tools in sequence and the conversation history to figure out next steps, if any. "
'{"next": "<question|confirm|done>", "tool": "<tool_name or null>", "args": {"<arg1>": "<value1 or null>", "<arg2>": "<value2 or null>}, "response": "<plain text>"}'
"ONLY return those json keys (next, tool, args, response), nothing else."
'Next should only be "done" if all tools have been run (use the system prompt to figure that out).'
'Next should be "question" if the tool is not the last one in the sequence.'
'Next should NOT be "confirm" at this point.'
)
async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool:
async def _handle_missing_args(
self, current_tool: str, args: Dict[str, Any], tool_data: ToolData
) -> bool:
"""Check for missing arguments and handle them if found."""
missing_args = [key for key, value in args.items() if value is None]
if missing_args:
self.prompt_queue.append(
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' "
f"and following missing arguments for tool {current_tool}: {missing_args}. "
"Only provide a valid JSON response without any comments or metadata."
)
workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}")
workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
return True
return False
@@ -76,15 +90,16 @@ class ToolWorkflow:
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput(
prompt=summary_prompt,
context_instructions=summary_context
prompt=summary_prompt, context_instructions=summary_context
)
self.conversation_summary = await workflow.start_activity_method(
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
)
workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.")
workflow.logger.info(
f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns."
)
workflow.continue_as_new(
args=[
CombinedInput(
@@ -146,14 +161,13 @@ class ToolWorkflow:
prompt=prompt,
context_instructions=context_instructions,
)
tool_data = await workflow.execute_activity(
ToolActivities.prompt_llm,
prompt_input,
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
retry_policy=RetryPolicy(
maximum_attempts=5,
initial_interval=timedelta(seconds=12)
maximum_attempts=5, initial_interval=timedelta(seconds=15)
),
)
self.tool_data = tool_data
@@ -219,7 +233,7 @@ class ToolWorkflow:
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
"""Generate a context-aware prompt with conversation history.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
@@ -234,7 +248,7 @@ class ToolWorkflow:
def prompt_summary_with_history(self) -> tuple[str, str]:
"""Generate a prompt for summarizing the conversation.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
@@ -248,7 +262,7 @@ class ToolWorkflow:
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
"""Add a message to the conversation history.
Args:
actor: The entity that generated the message (e.g., "user", "agent")
response: The message content, either as a string or structured data