Add dotnet activities for train interface.

This commit is contained in:
Rob Holland
2025-02-14 11:51:57 +00:00
parent 39462955eb
commit e085f02128
12 changed files with 157 additions and 23 deletions

2
enterprise/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
obj
bin

View File

@@ -0,0 +1,40 @@
using System.Net.Http.Json;
using Temporalio.Activities;
using TrainSearchWorker.Models;
namespace TrainSearchWorker.Activities;
public class TrainActivities
{
private readonly HttpClient _client;
public TrainActivities(IHttpClientFactory clientFactory)
{
_client = clientFactory.CreateClient("TrainApi");
}
[Activity]
public async Task<List<Journey>> SearchTrains(SearchTrainsRequest request)
{
var response = await _client.GetAsync(
$"api/search?from={Uri.EscapeDataString(request.From)}" +
$"&to={Uri.EscapeDataString(request.To)}" +
$"&outbound_time={Uri.EscapeDataString(request.OutboundTime)}" +
$"&return_time={Uri.EscapeDataString(request.ReturnTime)}");
response.EnsureSuccessStatusCode();
return await response.Content.ReadFromJsonAsync<List<Journey>>()
?? throw new InvalidOperationException("Received null response from API");
}
[Activity]
public async Task<List<Journey>> BookTrains(BookTrainsRequest request)
{
var response = await _client.PostAsJsonAsync("api/book", request);
response.EnsureSuccessStatusCode();
return await response.Content.ReadFromJsonAsync<List<Journey>>()
?? throw new InvalidOperationException("Received null response from API");
}
}

View File

@@ -0,0 +1,6 @@
namespace TrainSearchWorker.Models;
public record BookTrainsRequest
{
public required string TrainIds { get; init; }
}

View File

@@ -0,0 +1,12 @@
namespace TrainSearchWorker.Models;
public record Journey
{
public required string Id { get; init; }
public required string Type { get; init; }
public required string Departure { get; init; }
public required string Arrival { get; init; }
public required string DepartureTime { get; init; }
public required string ArrivalTime { get; init; }
public required decimal Price { get; init; }
}

View File

@@ -0,0 +1,9 @@
namespace TrainSearchWorker.Models;
public record SearchTrainsRequest
{
public required string From { get; init; }
public required string To { get; init; }
public required string OutboundTime { get; init; }
public required string ReturnTime { get; init; }
}

52
enterprise/Program.cs Normal file
View File

@@ -0,0 +1,52 @@
using Microsoft.Extensions.DependencyInjection;
using Temporalio.Client;
using Temporalio.Worker;
using TrainSearchWorker.Activities;
var services = new ServiceCollection();
// Add HTTP client
services.AddHttpClient("TrainApi", client =>
{
client.BaseAddress = new Uri("http://localhost:8080/");
client.DefaultRequestHeaders.Add("Accept", "application/json");
});
// Add activities
services.AddScoped<TrainActivities>();
var serviceProvider = services.BuildServiceProvider();
// Create client
var client = await TemporalClient.ConnectAsync(new()
{
TargetHost = "localhost:7233",
});
// Create worker options
var options = new TemporalWorkerOptions("agent-task-queue-legacy");
// Register activities
var activities = serviceProvider.GetRequiredService<TrainActivities>();
options.AddActivity(activities.SearchTrains);
options.AddActivity(activities.BookTrains);
// Create and run worker
var worker = new TemporalWorker(client, options);
Console.WriteLine("Starting worker...");
using var tokenSource = new CancellationTokenSource();
Console.CancelKeyPress += (_, eventArgs) =>
{
eventArgs.Cancel = true;
tokenSource.Cancel();
};
try
{
await worker.ExecuteAsync(tokenSource.Token);
}
catch (OperationCanceledException)
{
Console.WriteLine("Worker shutting down...");
}

View File

@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Temporalio" Version="1.0.0" />
<PackageReference Include="Microsoft.Extensions.Http" Version="8.0.0" />
</ItemGroup>
</Project>

View File

@@ -151,7 +151,7 @@ class TrainServer(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
parsed_url = urlparse(self.path) parsed_url = urlparse(self.path)
if parsed_url.path == "/api/journeys": if parsed_url.path == "/api/search":
try: try:
params = parse_qs(parsed_url.query) params = parse_qs(parsed_url.query)
origin = params.get("from", [""])[0] origin = params.get("from", [""])[0]
@@ -212,7 +212,7 @@ class TrainServer(BaseHTTPRequestHandler):
parsed_url = urlparse(self.path) parsed_url = urlparse(self.path)
if parsed_url.path.startswith("/api/book/"): if parsed_url.path.startswith("/api/book/"):
journey_id = parsed_url.path.split("/")[-1] train_ids = parsed_url.path.split("/")[-1].split(",")
self.send_response(200) self.send_response(200)
self.send_header("Content-Type", "application/json") self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
@@ -224,7 +224,7 @@ class TrainServer(BaseHTTPRequestHandler):
response = self.format_json( response = self.format_json(
{ {
"booking_reference": booking_ref, "booking_reference": booking_ref,
"journey_id": journey_id, "train_ids": train_ids,
"status": "confirmed", "status": "confirmed",
} }
) )

View File

@@ -1,7 +1,7 @@
from .search_fixtures import search_fixtures from .search_fixtures import search_fixtures
from .search_flights import search_flights from .search_flights import search_flights
from .search_trains import search_trains from .search_trains import search_trains
from .search_trains import book_train from .search_trains import book_trains
from .create_invoice import create_invoice from .create_invoice import create_invoice
@@ -12,8 +12,8 @@ def get_handler(tool_name: str):
return search_flights return search_flights
if tool_name == "SearchTrains": if tool_name == "SearchTrains":
return search_trains return search_trains
if tool_name == "BookTrain": if tool_name == "BookTrains":
return book_train return book_trains
if tool_name == "CreateInvoice": if tool_name == "CreateInvoice":
return create_invoice return create_invoice

View File

@@ -3,7 +3,7 @@ from tools.tool_registry import (
search_fixtures_tool, search_fixtures_tool,
search_flights_tool, search_flights_tool,
search_trains_tool, search_trains_tool,
book_train_tool, book_trains_tool,
create_invoice_tool, create_invoice_tool,
) )
@@ -11,13 +11,13 @@ goal_match_train_invoice = AgentGoal(
tools=[ tools=[
search_fixtures_tool, search_fixtures_tool,
search_trains_tool, search_trains_tool,
book_train_tool, book_trains_tool,
create_invoice_tool, create_invoice_tool,
], ],
description="Help the user book trains to a premier league match. The user lives in London. Gather args for these tools in order: " description="Help the user book trains to a premier league match. The user lives in London. Gather args for these tools in order: "
"1. SearchFixtures: Search for fixtures for a team in a given month" "1. SearchFixtures: Search for fixtures for a team in a given month"
"2. SearchTrains: Search for trains to the city of the match and list them for the customer to choose from" "2. SearchTrains: Search for trains to the city of the match and list them for the customer to choose from"
"3. BookTrain: Book the train tickets" "3. BookTrains: Book the train tickets"
"4. CreateInvoice: Proactively offer to create a simple invoice for the cost of the flights and train tickets", "4. CreateInvoice: Proactively offer to create a simple invoice for the cost of the flights and train tickets",
starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job", starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job",
example_conversation_history="\n ".join( example_conversation_history="\n ".join(
@@ -32,7 +32,7 @@ goal_match_train_invoice = AgentGoal(
"user_confirmed_tool_run: <user clicks confirm on SearchTrains tool>", "user_confirmed_tool_run: <user clicks confirm on SearchTrains tool>",
"tool_result: <results including train dates and times, origin and depature stations>", "tool_result: <results including train dates and times, origin and depature stations>",
"agent: Found some trains! <agent provides a human-readable list of train options>", "agent: Found some trains! <agent provides a human-readable list of train options>",
"user_confirmed_tool_run: <user clicks confirm on BookTrain tool>", "user_confirmed_tool_run: <user clicks confirm on BookTrains tool>",
'tool_result: results including {"status": "success"}', 'tool_result: results including {"status": "success"}',
"agent: Train tickets booked! Please confirm the following invoice for the journey. <agent infers total amount for the invoice and details from the conversation history>", "agent: Train tickets booked! Please confirm the following invoice for the journey. <agent infers total amount for the invoice and details from the conversation history>",
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool which includes details of the train journey, the match, and the total cost>", "user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool which includes details of the train journey, the match, and the total cost>",

View File

@@ -15,7 +15,7 @@ def search_trains(args: dict) -> dict:
if not origin or not destination or not outbound_time or not return_time: if not origin or not destination or not outbound_time or not return_time:
return {"error": "Origin, destination, outbound_time and return_time are required."} return {"error": "Origin, destination, outbound_time and return_time are required."}
search_url = f'{BASE_URL}/api/journeys' search_url = f'{BASE_URL}/api/search'
params = { params = {
'from': origin, 'from': origin,
'to': destination, 'to': destination,
@@ -31,15 +31,15 @@ def search_trains(args: dict) -> dict:
journey_data = response.json() journey_data = response.json()
return journey_data return journey_data
def book_train(args: dict) -> dict: def book_trains(args: dict) -> dict:
load_dotenv(override=True) load_dotenv(override=True)
journey_id = args.get("journey_id") train_ids = args.get("train_ids")
if not journey_id: if not train_ids:
return {"error": "Journey ID is required."} return {"error": "Train IDs is required."}
book_url = f'{BASE_URL}/api/book/{journey_id}' book_url = f'{BASE_URL}/api/book/{train_ids}'
response = requests.post(book_url) response = requests.post(book_url)
if response.status_code != 200: if response.status_code != 200:
return {"error": "Failed to book ticket."} return {"error": "Failed to book ticket."}
@@ -59,7 +59,7 @@ if __name__ == "__main__":
print(search_results) print(search_results)
book_args = { book_args = {
"journey_id": "12345", "train_ids": "12345",
} }
booking_results = book_train(book_args) booking_results = book_trains(book_args)
print(booking_results) print(booking_results)

View File

@@ -54,14 +54,14 @@ search_trains_tool = ToolDefinition(
], ],
) )
book_train_tool = ToolDefinition( book_trains_tool = ToolDefinition(
name="BookTrain", name="BookTrains",
description="Books a train ticket. Returns a booking reference.", description="Books train tickets. Returns a booking reference.",
arguments=[ arguments=[
ToolArgument( ToolArgument(
name="journey_id", name="train_ids",
type="string", type="string",
description="The ID of the journey to book", description="The IDs of the trains to book, comma separated",
), ),
], ],
) )