tcloud compatibility, dotnet bug fixes

This commit is contained in:
Steve Androulakis
2025-02-14 11:10:16 -08:00
parent a9d8a2a631
commit d6bbb900b7
12 changed files with 227 additions and 53 deletions

View File

@@ -1,20 +1,28 @@
using System.Net.Http.Json;
using System.Text.Json;
using Temporalio.Activities;
using TrainSearchWorker.Models;
using TrainSearchWorker.Converters;
namespace TrainSearchWorker.Activities;
public class TrainActivities
{
private readonly HttpClient _client;
private readonly JsonSerializerOptions _jsonOptions;
public TrainActivities(IHttpClientFactory clientFactory)
{
_client = clientFactory.CreateClient("TrainApi");
_jsonOptions = new JsonSerializerOptions
{
PropertyNameCaseInsensitive = true
};
}
[Activity]
public async Task<List<Journey>> SearchTrains(SearchTrainsRequest request)
public async Task<JourneyResponse> SearchTrains(SearchTrainsRequest request)
{
var response = await _client.GetAsync(
$"api/search?from={Uri.EscapeDataString(request.From)}" +
@@ -24,17 +32,28 @@ public class TrainActivities
response.EnsureSuccessStatusCode();
return await response.Content.ReadFromJsonAsync<List<Journey>>()
?? throw new InvalidOperationException("Received null response from API");
// Deserialize into JourneyResponse rather than List<Journey>
var journeyResponse = await response.Content.ReadFromJsonAsync<JourneyResponse>(_jsonOptions)
?? throw new InvalidOperationException("Received null response from API");
return journeyResponse;
}
[Activity]
public async Task<List<Journey>> BookTrains(BookTrainsRequest request)
public async Task<BookTrainsResponse> BookTrains(BookTrainsRequest request)
{
var response = await _client.PostAsJsonAsync("api/book", request);
// Build the URL using the train IDs from the request
var url = $"api/book/{Uri.EscapeDataString(request.TrainIds)}";
// POST with no JSON body, matching the Python version
var response = await _client.PostAsync(url, null);
response.EnsureSuccessStatusCode();
return await response.Content.ReadFromJsonAsync<List<Journey>>()
?? throw new InvalidOperationException("Received null response from API");
// Deserialize into a BookTrainsResponse (a single object)
var bookingResponse = await response.Content.ReadFromJsonAsync<BookTrainsResponse>(_jsonOptions)
?? throw new InvalidOperationException("Received null response from API");
return bookingResponse;
}
}

View File

@@ -1,6 +1,9 @@
using System.Text.Json.Serialization;
namespace TrainSearchWorker.Models;
public record BookTrainsRequest
{
[JsonPropertyName("train_ids")]
public required string TrainIds { get; init; }
}

View File

@@ -0,0 +1,17 @@
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace TrainSearchWorker.Models;
public record BookTrainsResponse
{
[JsonPropertyName("booking_reference")]
public required string BookingReference { get; init; }
// If the API now returns train_ids as an array, use List<string>
[JsonPropertyName("train_ids")]
public required List<string> TrainIds { get; init; }
[JsonPropertyName("status")]
public required string Status { get; init; }
}

View File

@@ -1,12 +1,27 @@
using System.Text.Json.Serialization;
namespace TrainSearchWorker.Models;
public record Journey
{
[JsonPropertyName("id")]
public required string Id { get; init; }
[JsonPropertyName("type")]
public required string Type { get; init; }
[JsonPropertyName("departure")]
public required string Departure { get; init; }
[JsonPropertyName("arrival")]
public required string Arrival { get; init; }
[JsonPropertyName("departure_time")]
public required string DepartureTime { get; init; }
[JsonPropertyName("arrival_time")]
public required string ArrivalTime { get; init; }
[JsonPropertyName("price")]
public required decimal Price { get; init; }
}

View File

@@ -0,0 +1,10 @@
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace TrainSearchWorker.Models;
public record JourneyResponse
{
[JsonPropertyName("journeys")]
public List<Journey>? Journeys { get; init; }
}

View File

@@ -1,9 +1,18 @@
using System.Text.Json.Serialization;
namespace TrainSearchWorker.Models;
public record SearchTrainsRequest
{
[JsonPropertyName("origin")]
public required string From { get; init; }
[JsonPropertyName("destination")]
public required string To { get; init; }
[JsonPropertyName("outbound_time")]
public required string OutboundTime { get; init; }
[JsonPropertyName("return_time")]
public required string ReturnTime { get; init; }
}

View File

@@ -3,6 +3,7 @@ using Temporalio.Client;
using Temporalio.Worker;
using TrainSearchWorker.Activities;
// Set up dependency injection
var services = new ServiceCollection();
// Add HTTP client
@@ -17,11 +18,17 @@ services.AddScoped<TrainActivities>();
var serviceProvider = services.BuildServiceProvider();
// Create client
var client = await TemporalClient.ConnectAsync(new()
{
TargetHost = "localhost:7233",
});
// Create client using the helper, which supports Temporal Cloud if environment variables are set
var client = await TemporalClientHelper.CreateClientAsync();
// Read connection details from environment or use defaults
var address = Environment.GetEnvironmentVariable("TEMPORAL_ADDRESS") ?? "localhost:7233";
var ns = Environment.GetEnvironmentVariable("TEMPORAL_NAMESPACE") ?? "default";
// Log connection details
Console.WriteLine("Starting worker...");
Console.WriteLine($"Connecting to Temporal at address: {address}");
Console.WriteLine($"Using namespace: {ns}");
// Create worker options
var options = new TemporalWorkerOptions("agent-task-queue-legacy");
@@ -34,7 +41,6 @@ options.AddActivity(activities.BookTrains);
// Create and run worker
var worker = new TemporalWorker(client, options);
Console.WriteLine("Starting worker...");
using var tokenSource = new CancellationTokenSource();
Console.CancelKeyPress += (_, eventArgs) =>
{

View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace TrainSearchWorker.Converters
{
public class SingleOrArrayConverter<T> : JsonConverter<List<T>>
{
public override List<T> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
if (reader.TokenType == JsonTokenType.StartArray)
{
return JsonSerializer.Deserialize<List<T>>(ref reader, options) ?? new List<T>();
}
else
{
// Single element wrap it in a list.
T element = JsonSerializer.Deserialize<T>(ref reader, options);
return new List<T> { element };
}
}
public override void Write(Utf8JsonWriter writer, List<T> value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value, options);
}
}
}

View File

@@ -0,0 +1,48 @@
using System;
using System.IO;
using System.Collections.Generic;
using Temporalio.Client;
public static class TemporalClientHelper
{
public static async Task<ITemporalClient> CreateClientAsync()
{
var address = Environment.GetEnvironmentVariable("TEMPORAL_ADDRESS") ?? "localhost:7233";
var ns = Environment.GetEnvironmentVariable("TEMPORAL_NAMESPACE") ?? "default";
var clientCertPath = Environment.GetEnvironmentVariable("TEMPORAL_TLS_CERT");
var clientKeyPath = Environment.GetEnvironmentVariable("TEMPORAL_TLS_KEY");
var apiKey = Environment.GetEnvironmentVariable("TEMPORAL_API_KEY");
var options = new TemporalClientConnectOptions(address)
{
Namespace = ns
};
if (!string.IsNullOrEmpty(clientCertPath) && !string.IsNullOrEmpty(clientKeyPath))
{
// mTLS authentication
options.Tls = new()
{
ClientCert = await File.ReadAllBytesAsync(clientCertPath),
ClientPrivateKey = await File.ReadAllBytesAsync(clientKeyPath),
};
}
else if (!string.IsNullOrEmpty(apiKey))
{
// API Key authentication
// TODO test
options.RpcMetadata = new Dictionary<string, string>()
{
["authorization"] = $"Bearer {apiKey}",
["temporal-namespace"] = ns
};
options.RpcMetadata = new Dictionary<string, string>()
{
["temporal-namespace"] = ns
};
options.Tls = new();
}
return await TemporalClient.ConnectAsync(options);
}
}

View File

@@ -197,6 +197,7 @@ class TrainServer(BaseHTTPRequestHandler):
journeys = self.generate_journeys(origin, destination, out_dt, ret_dt)
response = self.format_json({"journeys": journeys})
self.write_response(response)
except Exception as e:
@@ -228,6 +229,7 @@ class TrainServer(BaseHTTPRequestHandler):
"status": "confirmed",
}
)
self.write_response(response)
else:

View File

@@ -2,50 +2,34 @@ import requests
import os
from dotenv import load_dotenv
BASE_URL = 'http://localhost:8080/'
def search_trains(args: dict) -> dict:
load_dotenv(override=True)
raise NotImplementedError("TODO implement :)")
origin = args.get("origin")
destination = args.get("destination")
outbound_time = args.get("outbound_time")
return_time = args.get("return_time")
if not origin or not destination or not outbound_time or not return_time:
return {"error": "Origin, destination, outbound_time and return_time are required."}
search_url = f'{BASE_URL}/api/search'
params = {
'from': origin,
'to': destination,
'outbound_time': outbound_time,
'return_time': return_time,
}
response = requests.get(search_url, params=params)
if response.status_code != 200:
print(response.content)
return {"error": "Failed to fetch journey data."}
journey_data = response.json()
return journey_data
def book_trains(args: dict) -> dict:
load_dotenv(override=True)
raise NotImplementedError("TODO implement :)")
train_ids = args.get("train_ids")
if not train_ids:
return {"error": "Train IDs is required."}
# todo clean this up
# BASE_URL = "http://localhost:8080/"
book_url = f'{BASE_URL}/api/book/{train_ids}'
response = requests.post(book_url)
if response.status_code != 200:
return {"error": "Failed to book ticket."}
# def book_trains(args: dict) -> dict:
# load_dotenv(override=True)
# train_ids = args.get("train_ids")
# if not train_ids:
# return {"error": "Train IDs is required."}
# book_url = f"{BASE_URL}/api/book/{train_ids}"
# response = requests.post(book_url)
# if response.status_code != 200:
# return {"error": "Failed to book ticket."}
# booking_data = response.json()
# return booking_data
booking_data = response.json()
return booking_data
# Example usage
if __name__ == "__main__":
@@ -53,7 +37,7 @@ if __name__ == "__main__":
"origin": "London Gatwick",
"destination": "Manchester",
"outbound_time": "2025-03-15T14:00",
"return_time": "2025-03-20T14:00"
"return_time": "2025-03-20T14:00",
}
search_results = search_trains(search_args)
print(search_results)
@@ -63,3 +47,31 @@ if __name__ == "__main__":
}
booking_results = book_trains(book_args)
print(booking_results)
# def search_trains(args: dict) -> dict:
# load_dotenv(override=True)
# origin = args.get("origin")
# destination = args.get("destination")
# outbound_time = args.get("outbound_time")
# return_time = args.get("return_time")
# if not origin or not destination or not outbound_time or not return_time:
# return {"error": "Origin, destination, outbound_time and return_time are required."}
# search_url = f'{BASE_URL}/api/search'
# params = {
# 'from': origin,
# 'to': destination,
# 'outbound_time': outbound_time,
# 'return_time': return_time,
# }
# response = requests.get(search_url, params=params)
# if response.status_code != 200:
# print(response.content)
# return {"error": "Failed to fetch journey data."}
# journey_data = response.json()
# return journey_data

View File

@@ -49,7 +49,11 @@ class ToolWorkflow:
"""Execute a tool after confirmation and handle its result."""
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
task_queue = TEMPORAL_LEGACY_TASK_QUEUE if current_tool in ["SearchTrain", "BookTrain"] else None
task_queue = (
TEMPORAL_LEGACY_TASK_QUEUE
if current_tool in ["SearchTrains", "BookTrains"]
else None
)
try:
dynamic_result = await workflow.execute_activity(