diff --git a/enterprise/.gitignore b/enterprise/.gitignore new file mode 100644 index 0000000..7de5508 --- /dev/null +++ b/enterprise/.gitignore @@ -0,0 +1,2 @@ +obj +bin diff --git a/enterprise/Activities/TrainActivities.cs b/enterprise/Activities/TrainActivities.cs new file mode 100644 index 0000000..40de38e --- /dev/null +++ b/enterprise/Activities/TrainActivities.cs @@ -0,0 +1,40 @@ +using System.Net.Http.Json; +using Temporalio.Activities; +using TrainSearchWorker.Models; + +namespace TrainSearchWorker.Activities; + +public class TrainActivities +{ + private readonly HttpClient _client; + + public TrainActivities(IHttpClientFactory clientFactory) + { + _client = clientFactory.CreateClient("TrainApi"); + } + + [Activity] + public async Task> SearchTrains(SearchTrainsRequest request) + { + var response = await _client.GetAsync( + $"api/search?from={Uri.EscapeDataString(request.From)}" + + $"&to={Uri.EscapeDataString(request.To)}" + + $"&outbound_time={Uri.EscapeDataString(request.OutboundTime)}" + + $"&return_time={Uri.EscapeDataString(request.ReturnTime)}"); + + response.EnsureSuccessStatusCode(); + + return await response.Content.ReadFromJsonAsync>() + ?? throw new InvalidOperationException("Received null response from API"); + } + + [Activity] + public async Task> BookTrains(BookTrainsRequest request) + { + var response = await _client.PostAsJsonAsync("api/book", request); + response.EnsureSuccessStatusCode(); + + return await response.Content.ReadFromJsonAsync>() + ?? throw new InvalidOperationException("Received null response from API"); + } +} diff --git a/enterprise/Models/BookTrainRequest.cs b/enterprise/Models/BookTrainRequest.cs new file mode 100644 index 0000000..a89f966 --- /dev/null +++ b/enterprise/Models/BookTrainRequest.cs @@ -0,0 +1,6 @@ +namespace TrainSearchWorker.Models; + +public record BookTrainsRequest +{ + public required string TrainIds { get; init; } +} diff --git a/enterprise/Models/Journey.cs b/enterprise/Models/Journey.cs new file mode 100644 index 0000000..1038abd --- /dev/null +++ b/enterprise/Models/Journey.cs @@ -0,0 +1,12 @@ +namespace TrainSearchWorker.Models; + +public record Journey +{ + public required string Id { get; init; } + public required string Type { get; init; } + public required string Departure { get; init; } + public required string Arrival { get; init; } + public required string DepartureTime { get; init; } + public required string ArrivalTime { get; init; } + public required decimal Price { get; init; } +} diff --git a/enterprise/Models/SearchTrainsRequest.cs b/enterprise/Models/SearchTrainsRequest.cs new file mode 100644 index 0000000..a59a21d --- /dev/null +++ b/enterprise/Models/SearchTrainsRequest.cs @@ -0,0 +1,9 @@ +namespace TrainSearchWorker.Models; + +public record SearchTrainsRequest +{ + public required string From { get; init; } + public required string To { get; init; } + public required string OutboundTime { get; init; } + public required string ReturnTime { get; init; } +} diff --git a/enterprise/Program.cs b/enterprise/Program.cs new file mode 100644 index 0000000..51f03db --- /dev/null +++ b/enterprise/Program.cs @@ -0,0 +1,52 @@ +using Microsoft.Extensions.DependencyInjection; +using Temporalio.Client; +using Temporalio.Worker; +using TrainSearchWorker.Activities; + +var services = new ServiceCollection(); + +// Add HTTP client +services.AddHttpClient("TrainApi", client => +{ + client.BaseAddress = new Uri("http://localhost:8080/"); + client.DefaultRequestHeaders.Add("Accept", "application/json"); +}); + +// Add activities +services.AddScoped(); + +var serviceProvider = services.BuildServiceProvider(); + +// Create client +var client = await TemporalClient.ConnectAsync(new() +{ + TargetHost = "localhost:7233", +}); + +// Create worker options +var options = new TemporalWorkerOptions("agent-task-queue-legacy"); + +// Register activities +var activities = serviceProvider.GetRequiredService(); +options.AddActivity(activities.SearchTrains); +options.AddActivity(activities.BookTrains); + +// Create and run worker +var worker = new TemporalWorker(client, options); + +Console.WriteLine("Starting worker..."); +using var tokenSource = new CancellationTokenSource(); +Console.CancelKeyPress += (_, eventArgs) => +{ + eventArgs.Cancel = true; + tokenSource.Cancel(); +}; + +try +{ + await worker.ExecuteAsync(tokenSource.Token); +} +catch (OperationCanceledException) +{ + Console.WriteLine("Worker shutting down..."); +} \ No newline at end of file diff --git a/enterprise/TrainSearchWorker.csproj b/enterprise/TrainSearchWorker.csproj new file mode 100644 index 0000000..e18bff4 --- /dev/null +++ b/enterprise/TrainSearchWorker.csproj @@ -0,0 +1,13 @@ + + + Exe + net8.0 + enable + enable + + + + + + + diff --git a/thirdparty/train_api.py b/thirdparty/train_api.py index 1c40e45..7fd6a13 100644 --- a/thirdparty/train_api.py +++ b/thirdparty/train_api.py @@ -151,7 +151,7 @@ class TrainServer(BaseHTTPRequestHandler): def do_GET(self): parsed_url = urlparse(self.path) - if parsed_url.path == "/api/journeys": + if parsed_url.path == "/api/search": try: params = parse_qs(parsed_url.query) origin = params.get("from", [""])[0] @@ -212,7 +212,7 @@ class TrainServer(BaseHTTPRequestHandler): parsed_url = urlparse(self.path) if parsed_url.path.startswith("/api/book/"): - journey_id = parsed_url.path.split("/")[-1] + train_ids = parsed_url.path.split("/")[-1].split(",") self.send_response(200) self.send_header("Content-Type", "application/json") self.end_headers() @@ -224,7 +224,7 @@ class TrainServer(BaseHTTPRequestHandler): response = self.format_json( { "booking_reference": booking_ref, - "journey_id": journey_id, + "train_ids": train_ids, "status": "confirmed", } ) diff --git a/tools/__init__.py b/tools/__init__.py index f7772a9..93aa8ef 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -1,7 +1,7 @@ from .search_fixtures import search_fixtures from .search_flights import search_flights from .search_trains import search_trains -from .search_trains import book_train +from .search_trains import book_trains from .create_invoice import create_invoice @@ -12,8 +12,8 @@ def get_handler(tool_name: str): return search_flights if tool_name == "SearchTrains": return search_trains - if tool_name == "BookTrain": - return book_train + if tool_name == "BookTrains": + return book_trains if tool_name == "CreateInvoice": return create_invoice diff --git a/tools/goal_registry.py b/tools/goal_registry.py index 55f363d..858589d 100644 --- a/tools/goal_registry.py +++ b/tools/goal_registry.py @@ -3,7 +3,7 @@ from tools.tool_registry import ( search_fixtures_tool, search_flights_tool, search_trains_tool, - book_train_tool, + book_trains_tool, create_invoice_tool, ) @@ -11,13 +11,13 @@ goal_match_train_invoice = AgentGoal( tools=[ search_fixtures_tool, search_trains_tool, - book_train_tool, + book_trains_tool, create_invoice_tool, ], description="Help the user book trains to a premier league match. The user lives in London. Gather args for these tools in order: " "1. SearchFixtures: Search for fixtures for a team in a given month" "2. SearchTrains: Search for trains to the city of the match and list them for the customer to choose from" - "3. BookTrain: Book the train tickets" + "3. BookTrains: Book the train tickets" "4. CreateInvoice: Proactively offer to create a simple invoice for the cost of the flights and train tickets", starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job", example_conversation_history="\n ".join( @@ -32,7 +32,7 @@ goal_match_train_invoice = AgentGoal( "user_confirmed_tool_run: ", "tool_result: ", "agent: Found some trains! ", - "user_confirmed_tool_run: ", + "user_confirmed_tool_run: ", 'tool_result: results including {"status": "success"}', "agent: Train tickets booked! Please confirm the following invoice for the journey. ", "user_confirmed_tool_run: ", diff --git a/tools/search_trains.py b/tools/search_trains.py index 5ca5578..55c4745 100644 --- a/tools/search_trains.py +++ b/tools/search_trains.py @@ -15,7 +15,7 @@ def search_trains(args: dict) -> dict: if not origin or not destination or not outbound_time or not return_time: return {"error": "Origin, destination, outbound_time and return_time are required."} - search_url = f'{BASE_URL}/api/journeys' + search_url = f'{BASE_URL}/api/search' params = { 'from': origin, 'to': destination, @@ -31,15 +31,15 @@ def search_trains(args: dict) -> dict: journey_data = response.json() return journey_data -def book_train(args: dict) -> dict: +def book_trains(args: dict) -> dict: load_dotenv(override=True) - journey_id = args.get("journey_id") + train_ids = args.get("train_ids") - if not journey_id: - return {"error": "Journey ID is required."} + if not train_ids: + return {"error": "Train IDs is required."} - book_url = f'{BASE_URL}/api/book/{journey_id}' + book_url = f'{BASE_URL}/api/book/{train_ids}' response = requests.post(book_url) if response.status_code != 200: return {"error": "Failed to book ticket."} @@ -59,7 +59,7 @@ if __name__ == "__main__": print(search_results) book_args = { - "journey_id": "12345", + "train_ids": "12345", } - booking_results = book_train(book_args) + booking_results = book_trains(book_args) print(booking_results) \ No newline at end of file diff --git a/tools/tool_registry.py b/tools/tool_registry.py index 3480bd1..27f9d45 100644 --- a/tools/tool_registry.py +++ b/tools/tool_registry.py @@ -54,14 +54,14 @@ search_trains_tool = ToolDefinition( ], ) -book_train_tool = ToolDefinition( - name="BookTrain", - description="Books a train ticket. Returns a booking reference.", +book_trains_tool = ToolDefinition( + name="BookTrains", + description="Books train tickets. Returns a booking reference.", arguments=[ ToolArgument( - name="journey_id", + name="train_ids", type="string", - description="The ID of the journey to book", + description="The IDs of the trains to book, comma separated", ), ], )