refactor, date context

This commit is contained in:
Steve Androulakis
2025-01-01 13:16:18 -08:00
parent 8115f0d2df
commit e7e8e7e658
17 changed files with 118 additions and 90 deletions

21
scripts/end_chat.py Normal file
View File

@@ -0,0 +1,21 @@
import asyncio
from temporalio.client import Client
from workflows.tool_workflow import ToolWorkflow
async def main():
# Create client connected to server at the given address
client = await Client.connect("localhost:7233")
workflow_id = "ollama-agent"
handle = client.get_workflow_handle_for(ToolWorkflow.run, workflow_id)
# Sends a signal to the workflow
await handle.signal(ToolWorkflow.end_chat)
if __name__ == "__main__":
print("Sending signal to end chat.")
asyncio.run(main())

31
scripts/get_history.py Normal file
View File

@@ -0,0 +1,31 @@
import asyncio
from temporalio.client import Client
from workflows import ToolWorkflow
async def main():
# Create client connected to server at the given address
client = await Client.connect("localhost:7233")
workflow_id = "ollama-agent"
handle = client.get_workflow_handle(workflow_id)
# Queries the workflow for the conversation history
history = await handle.query(ToolWorkflow.get_conversation_history)
print("Conversation History")
print(
*(f"{speaker.title()}: {message}\n" for speaker, message in history), sep="\n"
)
# Queries the workflow for the conversation summary
summary = await handle.query(ToolWorkflow.get_summary_from_history)
if summary is not None:
print("Conversation Summary:")
print(summary)
if __name__ == "__main__":
asyncio.run(main())

23
scripts/run_ollama.py Normal file
View File

@@ -0,0 +1,23 @@
from ollama import chat, ChatResponse
def main():
model_name = "mistral"
# The messages to pass to the model
messages = [
{
"role": "user",
"content": "Why is the sky blue?",
}
]
# Call ollama's chat function
response: ChatResponse = chat(model=model_name, messages=messages)
# Print the full message content
print(response.message.content)
if __name__ == "__main__":
main()

36
scripts/run_worker.py Normal file
View File

@@ -0,0 +1,36 @@
import asyncio
import concurrent.futures
import logging
from temporalio.client import Client
from temporalio.worker import Worker
from activities.tool_activities import ToolActivities
from workflows.tool_workflow import ToolWorkflow
from workflows.parent_workflow import ParentWorkflow
async def main():
# Create client connected to server at the given address
client = await Client.connect("localhost:7233")
activities = ToolActivities()
# Run the worker
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
worker = Worker(
client,
task_queue="ollama-task-queue",
workflows=[ToolWorkflow, ParentWorkflow],
activities=[activities.prompt_llm, activities.parse_tool_data],
activity_executor=activity_executor,
)
await worker.run()
if __name__ == "__main__":
print("Starting worker")
print("Then run 'python send_message.py \"<prompt>\"'")
logging.basicConfig(level=logging.INFO)
asyncio.run(main())

68
scripts/send_message.py Normal file
View File

@@ -0,0 +1,68 @@
import asyncio
import sys
from temporalio.client import Client
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from models.tool_definitions import ToolDefinition, ToolArgument
from workflows.tool_workflow import ToolWorkflow
async def main(prompt):
# Construct your tool definitions in code
search_flights_tool = ToolDefinition(
name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range",
arguments=[
ToolArgument(
name="origin",
type="string",
description="Airport or city (infer airport code from city)",
),
ToolArgument(
name="destination",
type="string",
description="Airport or city code for arrival (infer airport code from city)",
),
ToolArgument(
name="dateDepart",
type="ISO8601",
description="Start of date range in human readable format",
),
ToolArgument(
name="dateReturn",
type="ISO8601",
description="End of date range in human readable format",
),
],
)
# Wrap it in ToolsData
tools_data = ToolsData(tools=[search_flights_tool])
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None), tools_data=tools_data
)
# Create client connected to Temporal server
client = await Client.connect("localhost:7233")
workflow_id = "ollama-agent"
# Start or signal the workflow, passing OllamaParams and tools_data
await client.start_workflow(
ToolWorkflow.run,
combined_input, # or pass custom summary/prompt_queue
id=workflow_id,
task_queue="ollama-task-queue",
start_signal="user_prompt",
start_signal_args=[prompt],
)
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python send_message.py '<prompt>'")
print("Example: python send_message.py 'What animals are marsupials?'")
else:
asyncio.run(main(sys.argv[1]))