mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 05:58:08 +01:00
tcloud compatibility, dotnet bug fixes
This commit is contained in:
@@ -1,20 +1,28 @@
|
|||||||
using System.Net.Http.Json;
|
using System.Net.Http.Json;
|
||||||
|
using System.Text.Json;
|
||||||
using Temporalio.Activities;
|
using Temporalio.Activities;
|
||||||
using TrainSearchWorker.Models;
|
using TrainSearchWorker.Models;
|
||||||
|
using TrainSearchWorker.Converters;
|
||||||
|
|
||||||
namespace TrainSearchWorker.Activities;
|
namespace TrainSearchWorker.Activities;
|
||||||
|
|
||||||
public class TrainActivities
|
public class TrainActivities
|
||||||
{
|
{
|
||||||
private readonly HttpClient _client;
|
private readonly HttpClient _client;
|
||||||
|
private readonly JsonSerializerOptions _jsonOptions;
|
||||||
|
|
||||||
public TrainActivities(IHttpClientFactory clientFactory)
|
public TrainActivities(IHttpClientFactory clientFactory)
|
||||||
{
|
{
|
||||||
_client = clientFactory.CreateClient("TrainApi");
|
_client = clientFactory.CreateClient("TrainApi");
|
||||||
|
_jsonOptions = new JsonSerializerOptions
|
||||||
|
{
|
||||||
|
PropertyNameCaseInsensitive = true
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
[Activity]
|
[Activity]
|
||||||
public async Task<List<Journey>> SearchTrains(SearchTrainsRequest request)
|
public async Task<JourneyResponse> SearchTrains(SearchTrainsRequest request)
|
||||||
{
|
{
|
||||||
var response = await _client.GetAsync(
|
var response = await _client.GetAsync(
|
||||||
$"api/search?from={Uri.EscapeDataString(request.From)}" +
|
$"api/search?from={Uri.EscapeDataString(request.From)}" +
|
||||||
@@ -24,17 +32,28 @@ public class TrainActivities
|
|||||||
|
|
||||||
response.EnsureSuccessStatusCode();
|
response.EnsureSuccessStatusCode();
|
||||||
|
|
||||||
return await response.Content.ReadFromJsonAsync<List<Journey>>()
|
// Deserialize into JourneyResponse rather than List<Journey>
|
||||||
|
var journeyResponse = await response.Content.ReadFromJsonAsync<JourneyResponse>(_jsonOptions)
|
||||||
?? throw new InvalidOperationException("Received null response from API");
|
?? throw new InvalidOperationException("Received null response from API");
|
||||||
|
|
||||||
|
return journeyResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
[Activity]
|
[Activity]
|
||||||
public async Task<List<Journey>> BookTrains(BookTrainsRequest request)
|
public async Task<BookTrainsResponse> BookTrains(BookTrainsRequest request)
|
||||||
{
|
{
|
||||||
var response = await _client.PostAsJsonAsync("api/book", request);
|
// Build the URL using the train IDs from the request
|
||||||
|
var url = $"api/book/{Uri.EscapeDataString(request.TrainIds)}";
|
||||||
|
|
||||||
|
// POST with no JSON body, matching the Python version
|
||||||
|
var response = await _client.PostAsync(url, null);
|
||||||
response.EnsureSuccessStatusCode();
|
response.EnsureSuccessStatusCode();
|
||||||
|
|
||||||
return await response.Content.ReadFromJsonAsync<List<Journey>>()
|
// Deserialize into a BookTrainsResponse (a single object)
|
||||||
|
var bookingResponse = await response.Content.ReadFromJsonAsync<BookTrainsResponse>(_jsonOptions)
|
||||||
?? throw new InvalidOperationException("Received null response from API");
|
?? throw new InvalidOperationException("Received null response from API");
|
||||||
|
|
||||||
|
return bookingResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
namespace TrainSearchWorker.Models;
|
namespace TrainSearchWorker.Models;
|
||||||
|
|
||||||
public record BookTrainsRequest
|
public record BookTrainsRequest
|
||||||
{
|
{
|
||||||
|
[JsonPropertyName("train_ids")]
|
||||||
public required string TrainIds { get; init; }
|
public required string TrainIds { get; init; }
|
||||||
}
|
}
|
||||||
|
|||||||
17
enterprise/Models/BookTrainsResponse.cs
Normal file
17
enterprise/Models/BookTrainsResponse.cs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
|
namespace TrainSearchWorker.Models;
|
||||||
|
|
||||||
|
public record BookTrainsResponse
|
||||||
|
{
|
||||||
|
[JsonPropertyName("booking_reference")]
|
||||||
|
public required string BookingReference { get; init; }
|
||||||
|
|
||||||
|
// If the API now returns train_ids as an array, use List<string>
|
||||||
|
[JsonPropertyName("train_ids")]
|
||||||
|
public required List<string> TrainIds { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("status")]
|
||||||
|
public required string Status { get; init; }
|
||||||
|
}
|
||||||
@@ -1,12 +1,27 @@
|
|||||||
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
namespace TrainSearchWorker.Models;
|
namespace TrainSearchWorker.Models;
|
||||||
|
|
||||||
public record Journey
|
public record Journey
|
||||||
{
|
{
|
||||||
|
[JsonPropertyName("id")]
|
||||||
public required string Id { get; init; }
|
public required string Id { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("type")]
|
||||||
public required string Type { get; init; }
|
public required string Type { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("departure")]
|
||||||
public required string Departure { get; init; }
|
public required string Departure { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("arrival")]
|
||||||
public required string Arrival { get; init; }
|
public required string Arrival { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("departure_time")]
|
||||||
public required string DepartureTime { get; init; }
|
public required string DepartureTime { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("arrival_time")]
|
||||||
public required string ArrivalTime { get; init; }
|
public required string ArrivalTime { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("price")]
|
||||||
public required decimal Price { get; init; }
|
public required decimal Price { get; init; }
|
||||||
}
|
}
|
||||||
10
enterprise/Models/JourneyResponse.cs
Normal file
10
enterprise/Models/JourneyResponse.cs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
|
namespace TrainSearchWorker.Models;
|
||||||
|
|
||||||
|
public record JourneyResponse
|
||||||
|
{
|
||||||
|
[JsonPropertyName("journeys")]
|
||||||
|
public List<Journey>? Journeys { get; init; }
|
||||||
|
}
|
||||||
@@ -1,9 +1,18 @@
|
|||||||
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
namespace TrainSearchWorker.Models;
|
namespace TrainSearchWorker.Models;
|
||||||
|
|
||||||
public record SearchTrainsRequest
|
public record SearchTrainsRequest
|
||||||
{
|
{
|
||||||
|
[JsonPropertyName("origin")]
|
||||||
public required string From { get; init; }
|
public required string From { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("destination")]
|
||||||
public required string To { get; init; }
|
public required string To { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("outbound_time")]
|
||||||
public required string OutboundTime { get; init; }
|
public required string OutboundTime { get; init; }
|
||||||
|
|
||||||
|
[JsonPropertyName("return_time")]
|
||||||
public required string ReturnTime { get; init; }
|
public required string ReturnTime { get; init; }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ using Temporalio.Client;
|
|||||||
using Temporalio.Worker;
|
using Temporalio.Worker;
|
||||||
using TrainSearchWorker.Activities;
|
using TrainSearchWorker.Activities;
|
||||||
|
|
||||||
|
// Set up dependency injection
|
||||||
var services = new ServiceCollection();
|
var services = new ServiceCollection();
|
||||||
|
|
||||||
// Add HTTP client
|
// Add HTTP client
|
||||||
@@ -17,11 +18,17 @@ services.AddScoped<TrainActivities>();
|
|||||||
|
|
||||||
var serviceProvider = services.BuildServiceProvider();
|
var serviceProvider = services.BuildServiceProvider();
|
||||||
|
|
||||||
// Create client
|
// Create client using the helper, which supports Temporal Cloud if environment variables are set
|
||||||
var client = await TemporalClient.ConnectAsync(new()
|
var client = await TemporalClientHelper.CreateClientAsync();
|
||||||
{
|
|
||||||
TargetHost = "localhost:7233",
|
// Read connection details from environment or use defaults
|
||||||
});
|
var address = Environment.GetEnvironmentVariable("TEMPORAL_ADDRESS") ?? "localhost:7233";
|
||||||
|
var ns = Environment.GetEnvironmentVariable("TEMPORAL_NAMESPACE") ?? "default";
|
||||||
|
|
||||||
|
// Log connection details
|
||||||
|
Console.WriteLine("Starting worker...");
|
||||||
|
Console.WriteLine($"Connecting to Temporal at address: {address}");
|
||||||
|
Console.WriteLine($"Using namespace: {ns}");
|
||||||
|
|
||||||
// Create worker options
|
// Create worker options
|
||||||
var options = new TemporalWorkerOptions("agent-task-queue-legacy");
|
var options = new TemporalWorkerOptions("agent-task-queue-legacy");
|
||||||
@@ -34,7 +41,6 @@ options.AddActivity(activities.BookTrains);
|
|||||||
// Create and run worker
|
// Create and run worker
|
||||||
var worker = new TemporalWorker(client, options);
|
var worker = new TemporalWorker(client, options);
|
||||||
|
|
||||||
Console.WriteLine("Starting worker...");
|
|
||||||
using var tokenSource = new CancellationTokenSource();
|
using var tokenSource = new CancellationTokenSource();
|
||||||
Console.CancelKeyPress += (_, eventArgs) =>
|
Console.CancelKeyPress += (_, eventArgs) =>
|
||||||
{
|
{
|
||||||
|
|||||||
29
enterprise/SingleOrArrayConverter.cs
Normal file
29
enterprise/SingleOrArrayConverter.cs
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Text.Json;
|
||||||
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
|
namespace TrainSearchWorker.Converters
|
||||||
|
{
|
||||||
|
public class SingleOrArrayConverter<T> : JsonConverter<List<T>>
|
||||||
|
{
|
||||||
|
public override List<T> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||||
|
{
|
||||||
|
if (reader.TokenType == JsonTokenType.StartArray)
|
||||||
|
{
|
||||||
|
return JsonSerializer.Deserialize<List<T>>(ref reader, options) ?? new List<T>();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Single element – wrap it in a list.
|
||||||
|
T element = JsonSerializer.Deserialize<T>(ref reader, options);
|
||||||
|
return new List<T> { element };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public override void Write(Utf8JsonWriter writer, List<T> value, JsonSerializerOptions options)
|
||||||
|
{
|
||||||
|
JsonSerializer.Serialize(writer, value, options);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
48
enterprise/TemporalClientHelper.cs
Normal file
48
enterprise/TemporalClientHelper.cs
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
using System;
|
||||||
|
using System.IO;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using Temporalio.Client;
|
||||||
|
|
||||||
|
public static class TemporalClientHelper
|
||||||
|
{
|
||||||
|
public static async Task<ITemporalClient> CreateClientAsync()
|
||||||
|
{
|
||||||
|
var address = Environment.GetEnvironmentVariable("TEMPORAL_ADDRESS") ?? "localhost:7233";
|
||||||
|
var ns = Environment.GetEnvironmentVariable("TEMPORAL_NAMESPACE") ?? "default";
|
||||||
|
var clientCertPath = Environment.GetEnvironmentVariable("TEMPORAL_TLS_CERT");
|
||||||
|
var clientKeyPath = Environment.GetEnvironmentVariable("TEMPORAL_TLS_KEY");
|
||||||
|
var apiKey = Environment.GetEnvironmentVariable("TEMPORAL_API_KEY");
|
||||||
|
|
||||||
|
var options = new TemporalClientConnectOptions(address)
|
||||||
|
{
|
||||||
|
Namespace = ns
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!string.IsNullOrEmpty(clientCertPath) && !string.IsNullOrEmpty(clientKeyPath))
|
||||||
|
{
|
||||||
|
// mTLS authentication
|
||||||
|
options.Tls = new()
|
||||||
|
{
|
||||||
|
ClientCert = await File.ReadAllBytesAsync(clientCertPath),
|
||||||
|
ClientPrivateKey = await File.ReadAllBytesAsync(clientKeyPath),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
else if (!string.IsNullOrEmpty(apiKey))
|
||||||
|
{
|
||||||
|
// API Key authentication
|
||||||
|
// TODO test
|
||||||
|
options.RpcMetadata = new Dictionary<string, string>()
|
||||||
|
{
|
||||||
|
["authorization"] = $"Bearer {apiKey}",
|
||||||
|
["temporal-namespace"] = ns
|
||||||
|
};
|
||||||
|
options.RpcMetadata = new Dictionary<string, string>()
|
||||||
|
{
|
||||||
|
["temporal-namespace"] = ns
|
||||||
|
};
|
||||||
|
options.Tls = new();
|
||||||
|
}
|
||||||
|
|
||||||
|
return await TemporalClient.ConnectAsync(options);
|
||||||
|
}
|
||||||
|
}
|
||||||
2
thirdparty/train_api.py
vendored
2
thirdparty/train_api.py
vendored
@@ -197,6 +197,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
journeys = self.generate_journeys(origin, destination, out_dt, ret_dt)
|
journeys = self.generate_journeys(origin, destination, out_dt, ret_dt)
|
||||||
response = self.format_json({"journeys": journeys})
|
response = self.format_json({"journeys": journeys})
|
||||||
|
|
||||||
self.write_response(response)
|
self.write_response(response)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -228,6 +229,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
"status": "confirmed",
|
"status": "confirmed",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.write_response(response)
|
self.write_response(response)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,50 +2,34 @@ import requests
|
|||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
BASE_URL = 'http://localhost:8080/'
|
|
||||||
|
|
||||||
def search_trains(args: dict) -> dict:
|
def search_trains(args: dict) -> dict:
|
||||||
load_dotenv(override=True)
|
raise NotImplementedError("TODO implement :)")
|
||||||
|
|
||||||
origin = args.get("origin")
|
|
||||||
destination = args.get("destination")
|
|
||||||
outbound_time = args.get("outbound_time")
|
|
||||||
return_time = args.get("return_time")
|
|
||||||
|
|
||||||
if not origin or not destination or not outbound_time or not return_time:
|
|
||||||
return {"error": "Origin, destination, outbound_time and return_time are required."}
|
|
||||||
|
|
||||||
search_url = f'{BASE_URL}/api/search'
|
|
||||||
params = {
|
|
||||||
'from': origin,
|
|
||||||
'to': destination,
|
|
||||||
'outbound_time': outbound_time,
|
|
||||||
'return_time': return_time,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.get(search_url, params=params)
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(response.content)
|
|
||||||
return {"error": "Failed to fetch journey data."}
|
|
||||||
|
|
||||||
journey_data = response.json()
|
|
||||||
return journey_data
|
|
||||||
|
|
||||||
def book_trains(args: dict) -> dict:
|
def book_trains(args: dict) -> dict:
|
||||||
load_dotenv(override=True)
|
raise NotImplementedError("TODO implement :)")
|
||||||
|
|
||||||
train_ids = args.get("train_ids")
|
|
||||||
|
|
||||||
if not train_ids:
|
# todo clean this up
|
||||||
return {"error": "Train IDs is required."}
|
# BASE_URL = "http://localhost:8080/"
|
||||||
|
|
||||||
book_url = f'{BASE_URL}/api/book/{train_ids}'
|
# def book_trains(args: dict) -> dict:
|
||||||
response = requests.post(book_url)
|
# load_dotenv(override=True)
|
||||||
if response.status_code != 200:
|
|
||||||
return {"error": "Failed to book ticket."}
|
# train_ids = args.get("train_ids")
|
||||||
|
|
||||||
|
# if not train_ids:
|
||||||
|
# return {"error": "Train IDs is required."}
|
||||||
|
|
||||||
|
# book_url = f"{BASE_URL}/api/book/{train_ids}"
|
||||||
|
# response = requests.post(book_url)
|
||||||
|
# if response.status_code != 200:
|
||||||
|
# return {"error": "Failed to book ticket."}
|
||||||
|
|
||||||
|
# booking_data = response.json()
|
||||||
|
# return booking_data
|
||||||
|
|
||||||
booking_data = response.json()
|
|
||||||
return booking_data
|
|
||||||
|
|
||||||
# Example usage
|
# Example usage
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -53,7 +37,7 @@ if __name__ == "__main__":
|
|||||||
"origin": "London Gatwick",
|
"origin": "London Gatwick",
|
||||||
"destination": "Manchester",
|
"destination": "Manchester",
|
||||||
"outbound_time": "2025-03-15T14:00",
|
"outbound_time": "2025-03-15T14:00",
|
||||||
"return_time": "2025-03-20T14:00"
|
"return_time": "2025-03-20T14:00",
|
||||||
}
|
}
|
||||||
search_results = search_trains(search_args)
|
search_results = search_trains(search_args)
|
||||||
print(search_results)
|
print(search_results)
|
||||||
@@ -63,3 +47,31 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
booking_results = book_trains(book_args)
|
booking_results = book_trains(book_args)
|
||||||
print(booking_results)
|
print(booking_results)
|
||||||
|
|
||||||
|
|
||||||
|
# def search_trains(args: dict) -> dict:
|
||||||
|
# load_dotenv(override=True)
|
||||||
|
|
||||||
|
# origin = args.get("origin")
|
||||||
|
# destination = args.get("destination")
|
||||||
|
# outbound_time = args.get("outbound_time")
|
||||||
|
# return_time = args.get("return_time")
|
||||||
|
|
||||||
|
# if not origin or not destination or not outbound_time or not return_time:
|
||||||
|
# return {"error": "Origin, destination, outbound_time and return_time are required."}
|
||||||
|
|
||||||
|
# search_url = f'{BASE_URL}/api/search'
|
||||||
|
# params = {
|
||||||
|
# 'from': origin,
|
||||||
|
# 'to': destination,
|
||||||
|
# 'outbound_time': outbound_time,
|
||||||
|
# 'return_time': return_time,
|
||||||
|
# }
|
||||||
|
|
||||||
|
# response = requests.get(search_url, params=params)
|
||||||
|
# if response.status_code != 200:
|
||||||
|
# print(response.content)
|
||||||
|
# return {"error": "Failed to fetch journey data."}
|
||||||
|
|
||||||
|
# journey_data = response.json()
|
||||||
|
# return journey_data
|
||||||
|
|||||||
@@ -49,7 +49,11 @@ class ToolWorkflow:
|
|||||||
"""Execute a tool after confirmation and handle its result."""
|
"""Execute a tool after confirmation and handle its result."""
|
||||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||||
|
|
||||||
task_queue = TEMPORAL_LEGACY_TASK_QUEUE if current_tool in ["SearchTrain", "BookTrain"] else None
|
task_queue = (
|
||||||
|
TEMPORAL_LEGACY_TASK_QUEUE
|
||||||
|
if current_tool in ["SearchTrains", "BookTrains"]
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dynamic_result = await workflow.execute_activity(
|
dynamic_result = await workflow.execute_activity(
|
||||||
|
|||||||
Reference in New Issue
Block a user