mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
refactor, date context
This commit is contained in:
21
scripts/end_chat.py
Normal file
21
scripts/end_chat.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import asyncio
|
||||
|
||||
from temporalio.client import Client
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
async def main():
|
||||
# Create client connected to server at the given address
|
||||
client = await Client.connect("localhost:7233")
|
||||
|
||||
workflow_id = "ollama-agent"
|
||||
|
||||
handle = client.get_workflow_handle_for(ToolWorkflow.run, workflow_id)
|
||||
|
||||
# Sends a signal to the workflow
|
||||
await handle.signal(ToolWorkflow.end_chat)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Sending signal to end chat.")
|
||||
asyncio.run(main())
|
||||
31
scripts/get_history.py
Normal file
31
scripts/get_history.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import asyncio
|
||||
|
||||
from temporalio.client import Client
|
||||
from workflows import ToolWorkflow
|
||||
|
||||
|
||||
async def main():
|
||||
# Create client connected to server at the given address
|
||||
client = await Client.connect("localhost:7233")
|
||||
workflow_id = "ollama-agent"
|
||||
|
||||
handle = client.get_workflow_handle(workflow_id)
|
||||
|
||||
# Queries the workflow for the conversation history
|
||||
history = await handle.query(ToolWorkflow.get_conversation_history)
|
||||
|
||||
print("Conversation History")
|
||||
print(
|
||||
*(f"{speaker.title()}: {message}\n" for speaker, message in history), sep="\n"
|
||||
)
|
||||
|
||||
# Queries the workflow for the conversation summary
|
||||
summary = await handle.query(ToolWorkflow.get_summary_from_history)
|
||||
|
||||
if summary is not None:
|
||||
print("Conversation Summary:")
|
||||
print(summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
23
scripts/run_ollama.py
Normal file
23
scripts/run_ollama.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from ollama import chat, ChatResponse
|
||||
|
||||
|
||||
def main():
|
||||
model_name = "mistral"
|
||||
|
||||
# The messages to pass to the model
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Why is the sky blue?",
|
||||
}
|
||||
]
|
||||
|
||||
# Call ollama's chat function
|
||||
response: ChatResponse = chat(model=model_name, messages=messages)
|
||||
|
||||
# Print the full message content
|
||||
print(response.message.content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
36
scripts/run_worker.py
Normal file
36
scripts/run_worker.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
|
||||
from activities.tool_activities import ToolActivities
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
from workflows.parent_workflow import ParentWorkflow
|
||||
|
||||
|
||||
async def main():
|
||||
# Create client connected to server at the given address
|
||||
client = await Client.connect("localhost:7233")
|
||||
activities = ToolActivities()
|
||||
|
||||
# Run the worker
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||
worker = Worker(
|
||||
client,
|
||||
task_queue="ollama-task-queue",
|
||||
workflows=[ToolWorkflow, ParentWorkflow],
|
||||
activities=[activities.prompt_llm, activities.parse_tool_data],
|
||||
activity_executor=activity_executor,
|
||||
)
|
||||
await worker.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting worker")
|
||||
print("Then run 'python send_message.py \"<prompt>\"'")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
asyncio.run(main())
|
||||
68
scripts/send_message.py
Normal file
68
scripts/send_message.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from temporalio.client import Client
|
||||
|
||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||
from models.tool_definitions import ToolDefinition, ToolArgument
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
async def main(prompt):
|
||||
# Construct your tool definitions in code
|
||||
search_flights_tool = ToolDefinition(
|
||||
name="SearchFlights",
|
||||
description="Search for return flights from an origin to a destination within a date range",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="origin",
|
||||
type="string",
|
||||
description="Airport or city (infer airport code from city)",
|
||||
),
|
||||
ToolArgument(
|
||||
name="destination",
|
||||
type="string",
|
||||
description="Airport or city code for arrival (infer airport code from city)",
|
||||
),
|
||||
ToolArgument(
|
||||
name="dateDepart",
|
||||
type="ISO8601",
|
||||
description="Start of date range in human readable format",
|
||||
),
|
||||
ToolArgument(
|
||||
name="dateReturn",
|
||||
type="ISO8601",
|
||||
description="End of date range in human readable format",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Wrap it in ToolsData
|
||||
tools_data = ToolsData(tools=[search_flights_tool])
|
||||
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None), tools_data=tools_data
|
||||
)
|
||||
|
||||
# Create client connected to Temporal server
|
||||
client = await Client.connect("localhost:7233")
|
||||
|
||||
workflow_id = "ollama-agent"
|
||||
|
||||
# Start or signal the workflow, passing OllamaParams and tools_data
|
||||
await client.start_workflow(
|
||||
ToolWorkflow.run,
|
||||
combined_input, # or pass custom summary/prompt_queue
|
||||
id=workflow_id,
|
||||
task_queue="ollama-task-queue",
|
||||
start_signal="user_prompt",
|
||||
start_signal_args=[prompt],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python send_message.py '<prompt>'")
|
||||
print("Example: python send_message.py 'What animals are marsupials?'")
|
||||
else:
|
||||
asyncio.run(main(sys.argv[1]))
|
||||
Reference in New Issue
Block a user