diff --git a/enterprise/Activities/TrainActivities.cs b/enterprise/Activities/TrainActivities.cs index 40de38e..e4aa88f 100644 --- a/enterprise/Activities/TrainActivities.cs +++ b/enterprise/Activities/TrainActivities.cs @@ -1,20 +1,28 @@ using System.Net.Http.Json; +using System.Text.Json; using Temporalio.Activities; using TrainSearchWorker.Models; +using TrainSearchWorker.Converters; namespace TrainSearchWorker.Activities; public class TrainActivities { private readonly HttpClient _client; + private readonly JsonSerializerOptions _jsonOptions; public TrainActivities(IHttpClientFactory clientFactory) { _client = clientFactory.CreateClient("TrainApi"); + _jsonOptions = new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }; } + [Activity] - public async Task> SearchTrains(SearchTrainsRequest request) + public async Task SearchTrains(SearchTrainsRequest request) { var response = await _client.GetAsync( $"api/search?from={Uri.EscapeDataString(request.From)}" + @@ -24,17 +32,28 @@ public class TrainActivities response.EnsureSuccessStatusCode(); - return await response.Content.ReadFromJsonAsync>() - ?? throw new InvalidOperationException("Received null response from API"); + // Deserialize into JourneyResponse rather than List + var journeyResponse = await response.Content.ReadFromJsonAsync(_jsonOptions) + ?? throw new InvalidOperationException("Received null response from API"); + + return journeyResponse; } [Activity] - public async Task> BookTrains(BookTrainsRequest request) + public async Task BookTrains(BookTrainsRequest request) { - var response = await _client.PostAsJsonAsync("api/book", request); + // Build the URL using the train IDs from the request + var url = $"api/book/{Uri.EscapeDataString(request.TrainIds)}"; + + // POST with no JSON body, matching the Python version + var response = await _client.PostAsync(url, null); response.EnsureSuccessStatusCode(); - return await response.Content.ReadFromJsonAsync>() - ?? throw new InvalidOperationException("Received null response from API"); + // Deserialize into a BookTrainsResponse (a single object) + var bookingResponse = await response.Content.ReadFromJsonAsync(_jsonOptions) + ?? throw new InvalidOperationException("Received null response from API"); + + return bookingResponse; } + } diff --git a/enterprise/Models/BookTrainRequest.cs b/enterprise/Models/BookTrainRequest.cs index a89f966..cf6b1fb 100644 --- a/enterprise/Models/BookTrainRequest.cs +++ b/enterprise/Models/BookTrainRequest.cs @@ -1,6 +1,9 @@ +using System.Text.Json.Serialization; + namespace TrainSearchWorker.Models; public record BookTrainsRequest { + [JsonPropertyName("train_ids")] public required string TrainIds { get; init; } } diff --git a/enterprise/Models/BookTrainsResponse.cs b/enterprise/Models/BookTrainsResponse.cs new file mode 100644 index 0000000..1059552 --- /dev/null +++ b/enterprise/Models/BookTrainsResponse.cs @@ -0,0 +1,17 @@ +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace TrainSearchWorker.Models; + +public record BookTrainsResponse +{ + [JsonPropertyName("booking_reference")] + public required string BookingReference { get; init; } + + // If the API now returns train_ids as an array, use List + [JsonPropertyName("train_ids")] + public required List TrainIds { get; init; } + + [JsonPropertyName("status")] + public required string Status { get; init; } +} diff --git a/enterprise/Models/Journey.cs b/enterprise/Models/Journey.cs index 1038abd..cce3798 100644 --- a/enterprise/Models/Journey.cs +++ b/enterprise/Models/Journey.cs @@ -1,12 +1,27 @@ +using System.Text.Json.Serialization; + namespace TrainSearchWorker.Models; public record Journey { + [JsonPropertyName("id")] public required string Id { get; init; } + + [JsonPropertyName("type")] public required string Type { get; init; } + + [JsonPropertyName("departure")] public required string Departure { get; init; } + + [JsonPropertyName("arrival")] public required string Arrival { get; init; } + + [JsonPropertyName("departure_time")] public required string DepartureTime { get; init; } + + [JsonPropertyName("arrival_time")] public required string ArrivalTime { get; init; } + + [JsonPropertyName("price")] public required decimal Price { get; init; } -} +} \ No newline at end of file diff --git a/enterprise/Models/JourneyResponse.cs b/enterprise/Models/JourneyResponse.cs new file mode 100644 index 0000000..f762b08 --- /dev/null +++ b/enterprise/Models/JourneyResponse.cs @@ -0,0 +1,10 @@ +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace TrainSearchWorker.Models; + +public record JourneyResponse +{ + [JsonPropertyName("journeys")] + public List? Journeys { get; init; } +} diff --git a/enterprise/Models/SearchTrainsRequest.cs b/enterprise/Models/SearchTrainsRequest.cs index a59a21d..719d412 100644 --- a/enterprise/Models/SearchTrainsRequest.cs +++ b/enterprise/Models/SearchTrainsRequest.cs @@ -1,9 +1,18 @@ +using System.Text.Json.Serialization; + namespace TrainSearchWorker.Models; public record SearchTrainsRequest { + [JsonPropertyName("origin")] public required string From { get; init; } + + [JsonPropertyName("destination")] public required string To { get; init; } + + [JsonPropertyName("outbound_time")] public required string OutboundTime { get; init; } + + [JsonPropertyName("return_time")] public required string ReturnTime { get; init; } } diff --git a/enterprise/Program.cs b/enterprise/Program.cs index 51f03db..c8fbaa5 100644 --- a/enterprise/Program.cs +++ b/enterprise/Program.cs @@ -3,6 +3,7 @@ using Temporalio.Client; using Temporalio.Worker; using TrainSearchWorker.Activities; +// Set up dependency injection var services = new ServiceCollection(); // Add HTTP client @@ -17,11 +18,17 @@ services.AddScoped(); var serviceProvider = services.BuildServiceProvider(); -// Create client -var client = await TemporalClient.ConnectAsync(new() -{ - TargetHost = "localhost:7233", -}); +// Create client using the helper, which supports Temporal Cloud if environment variables are set +var client = await TemporalClientHelper.CreateClientAsync(); + +// Read connection details from environment or use defaults +var address = Environment.GetEnvironmentVariable("TEMPORAL_ADDRESS") ?? "localhost:7233"; +var ns = Environment.GetEnvironmentVariable("TEMPORAL_NAMESPACE") ?? "default"; + +// Log connection details +Console.WriteLine("Starting worker..."); +Console.WriteLine($"Connecting to Temporal at address: {address}"); +Console.WriteLine($"Using namespace: {ns}"); // Create worker options var options = new TemporalWorkerOptions("agent-task-queue-legacy"); @@ -34,7 +41,6 @@ options.AddActivity(activities.BookTrains); // Create and run worker var worker = new TemporalWorker(client, options); -Console.WriteLine("Starting worker..."); using var tokenSource = new CancellationTokenSource(); Console.CancelKeyPress += (_, eventArgs) => { @@ -49,4 +55,4 @@ try catch (OperationCanceledException) { Console.WriteLine("Worker shutting down..."); -} \ No newline at end of file +} diff --git a/enterprise/SingleOrArrayConverter.cs b/enterprise/SingleOrArrayConverter.cs new file mode 100644 index 0000000..f98f27e --- /dev/null +++ b/enterprise/SingleOrArrayConverter.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace TrainSearchWorker.Converters +{ + public class SingleOrArrayConverter : JsonConverter> + { + public override List Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.StartArray) + { + return JsonSerializer.Deserialize>(ref reader, options) ?? new List(); + } + else + { + // Single element – wrap it in a list. + T element = JsonSerializer.Deserialize(ref reader, options); + return new List { element }; + } + } + + public override void Write(Utf8JsonWriter writer, List value, JsonSerializerOptions options) + { + JsonSerializer.Serialize(writer, value, options); + } + } +} diff --git a/enterprise/TemporalClientHelper.cs b/enterprise/TemporalClientHelper.cs new file mode 100644 index 0000000..95e841f --- /dev/null +++ b/enterprise/TemporalClientHelper.cs @@ -0,0 +1,48 @@ +using System; +using System.IO; +using System.Collections.Generic; +using Temporalio.Client; + +public static class TemporalClientHelper +{ + public static async Task CreateClientAsync() + { + var address = Environment.GetEnvironmentVariable("TEMPORAL_ADDRESS") ?? "localhost:7233"; + var ns = Environment.GetEnvironmentVariable("TEMPORAL_NAMESPACE") ?? "default"; + var clientCertPath = Environment.GetEnvironmentVariable("TEMPORAL_TLS_CERT"); + var clientKeyPath = Environment.GetEnvironmentVariable("TEMPORAL_TLS_KEY"); + var apiKey = Environment.GetEnvironmentVariable("TEMPORAL_API_KEY"); + + var options = new TemporalClientConnectOptions(address) + { + Namespace = ns + }; + + if (!string.IsNullOrEmpty(clientCertPath) && !string.IsNullOrEmpty(clientKeyPath)) + { + // mTLS authentication + options.Tls = new() + { + ClientCert = await File.ReadAllBytesAsync(clientCertPath), + ClientPrivateKey = await File.ReadAllBytesAsync(clientKeyPath), + }; + } + else if (!string.IsNullOrEmpty(apiKey)) + { + // API Key authentication + // TODO test + options.RpcMetadata = new Dictionary() + { + ["authorization"] = $"Bearer {apiKey}", + ["temporal-namespace"] = ns + }; + options.RpcMetadata = new Dictionary() + { + ["temporal-namespace"] = ns + }; + options.Tls = new(); + } + + return await TemporalClient.ConnectAsync(options); + } +} \ No newline at end of file diff --git a/thirdparty/train_api.py b/thirdparty/train_api.py index 7fd6a13..e431e88 100644 --- a/thirdparty/train_api.py +++ b/thirdparty/train_api.py @@ -197,6 +197,7 @@ class TrainServer(BaseHTTPRequestHandler): journeys = self.generate_journeys(origin, destination, out_dt, ret_dt) response = self.format_json({"journeys": journeys}) + self.write_response(response) except Exception as e: @@ -228,6 +229,7 @@ class TrainServer(BaseHTTPRequestHandler): "status": "confirmed", } ) + self.write_response(response) else: diff --git a/tools/search_trains.py b/tools/search_trains.py index 55c4745..b345dba 100644 --- a/tools/search_trains.py +++ b/tools/search_trains.py @@ -2,50 +2,34 @@ import requests import os from dotenv import load_dotenv -BASE_URL = 'http://localhost:8080/' def search_trains(args: dict) -> dict: - load_dotenv(override=True) + raise NotImplementedError("TODO implement :)") - origin = args.get("origin") - destination = args.get("destination") - outbound_time = args.get("outbound_time") - return_time = args.get("return_time") - - if not origin or not destination or not outbound_time or not return_time: - return {"error": "Origin, destination, outbound_time and return_time are required."} - - search_url = f'{BASE_URL}/api/search' - params = { - 'from': origin, - 'to': destination, - 'outbound_time': outbound_time, - 'return_time': return_time, - } - - response = requests.get(search_url, params=params) - if response.status_code != 200: - print(response.content) - return {"error": "Failed to fetch journey data."} - - journey_data = response.json() - return journey_data def book_trains(args: dict) -> dict: - load_dotenv(override=True) + raise NotImplementedError("TODO implement :)") - train_ids = args.get("train_ids") - if not train_ids: - return {"error": "Train IDs is required."} +# todo clean this up +# BASE_URL = "http://localhost:8080/" - book_url = f'{BASE_URL}/api/book/{train_ids}' - response = requests.post(book_url) - if response.status_code != 200: - return {"error": "Failed to book ticket."} +# def book_trains(args: dict) -> dict: +# load_dotenv(override=True) + +# train_ids = args.get("train_ids") + +# if not train_ids: +# return {"error": "Train IDs is required."} + +# book_url = f"{BASE_URL}/api/book/{train_ids}" +# response = requests.post(book_url) +# if response.status_code != 200: +# return {"error": "Failed to book ticket."} + +# booking_data = response.json() +# return booking_data - booking_data = response.json() - return booking_data # Example usage if __name__ == "__main__": @@ -53,7 +37,7 @@ if __name__ == "__main__": "origin": "London Gatwick", "destination": "Manchester", "outbound_time": "2025-03-15T14:00", - "return_time": "2025-03-20T14:00" + "return_time": "2025-03-20T14:00", } search_results = search_trains(search_args) print(search_results) @@ -62,4 +46,32 @@ if __name__ == "__main__": "train_ids": "12345", } booking_results = book_trains(book_args) - print(booking_results) \ No newline at end of file + print(booking_results) + + +# def search_trains(args: dict) -> dict: +# load_dotenv(override=True) + +# origin = args.get("origin") +# destination = args.get("destination") +# outbound_time = args.get("outbound_time") +# return_time = args.get("return_time") + +# if not origin or not destination or not outbound_time or not return_time: +# return {"error": "Origin, destination, outbound_time and return_time are required."} + +# search_url = f'{BASE_URL}/api/search' +# params = { +# 'from': origin, +# 'to': destination, +# 'outbound_time': outbound_time, +# 'return_time': return_time, +# } + +# response = requests.get(search_url, params=params) +# if response.status_code != 200: +# print(response.content) +# return {"error": "Failed to fetch journey data."} + +# journey_data = response.json() +# return journey_data diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 97893e3..c12cad5 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -49,7 +49,11 @@ class ToolWorkflow: """Execute a tool after confirmation and handle its result.""" workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") - task_queue = TEMPORAL_LEGACY_TASK_QUEUE if current_tool in ["SearchTrain", "BookTrain"] else None + task_queue = ( + TEMPORAL_LEGACY_TASK_QUEUE + if current_tool in ["SearchTrains", "BookTrains"] + else None + ) try: dynamic_result = await workflow.execute_activity(