diff --git a/src/Microsoft.Tye.Hosting/HttpProxyService.cs b/src/Microsoft.Tye.Hosting/HttpProxyService.cs index e2053d32..794f338b 100644 --- a/src/Microsoft.Tye.Hosting/HttpProxyService.cs +++ b/src/Microsoft.Tye.Hosting/HttpProxyService.cs @@ -113,6 +113,8 @@ namespace Microsoft.Tye.Hosting builder.Configure(app => { + app.UseWebSockets(); + app.UseRouting(); app.UseEndpoints(endpointBuilder => diff --git a/test/E2ETest/TyeRunTests.cs b/test/E2ETest/TyeRunTests.cs index 13c244d8..22047301 100644 --- a/test/E2ETest/TyeRunTests.cs +++ b/test/E2ETest/TyeRunTests.cs @@ -11,10 +11,14 @@ using System.Net; using System.Net.Http; using System.Net.NetworkInformation; using System.Net.Sockets; +using System.Net.WebSockets; using System.Runtime.InteropServices; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; +using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Primitives; using Microsoft.Tye; using Microsoft.Tye.Hosting; using Microsoft.Tye.Hosting.Model; @@ -669,6 +673,7 @@ services: }; var client = new HttpClient(new RetryHandler(handler)); + var wsClient = new ClientWebSocket(); await RunHostingApplication(application, new HostOptions(), async (app, uri) => { @@ -702,6 +707,34 @@ services: // checking preservePath behavior var responsePreservePath = await client.GetAsync(ingressUri + "/C/test"); Assert.Contains("Hit path /C/test", await responsePreservePath.Content.ReadAsStringAsync()); + + string GetWebSocketUri(string uri) + { + if (uri.StartsWith("http")) + { + return "ws" + uri.Substring(4); + } + else if (uri.StartsWith("https")) + { + return "wss" + uri.Substring(5); + } + + throw new NotSupportedException(); + } + + // Check the websocket endpoint + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + var wsUri = GetWebSocketUri(ingressUri); + await wsClient.ConnectAsync(new Uri(wsUri + "/A/ws"), cts.Token); + var data = Encoding.UTF8.GetBytes("Hello World"); + await wsClient.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cts.Token); + var receiveBuffer = new byte[4096]; + var result = await wsClient.ReceiveAsync(receiveBuffer.AsMemory(), cts.Token); + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal(data.Length, result.Count); + Assert.Equal(data, receiveBuffer.AsMemory(0, result.Count).ToArray()); + await wsClient.CloseAsync(WebSocketCloseStatus.NormalClosure, "", cts.Token); }); } diff --git a/test/E2ETest/testassets/projects/apps-with-ingress/ApplicationA/Startup.cs b/test/E2ETest/testassets/projects/apps-with-ingress/ApplicationA/Startup.cs index 0d0e171a..3aa0f20c 100644 --- a/test/E2ETest/testassets/projects/apps-with-ingress/ApplicationA/Startup.cs +++ b/test/E2ETest/testassets/projects/apps-with-ingress/ApplicationA/Startup.cs @@ -1,6 +1,8 @@ using System; using System.IO; +using System.Net.WebSockets; using System.Text.Json; +using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; @@ -25,6 +27,8 @@ namespace ApplicationA app.UseDeveloperExceptionPage(); } + app.UseWebSockets(); + app.UseRouting(); app.UseEndpoints(endpoints => @@ -53,6 +57,31 @@ namespace ApplicationA query })); }); + + endpoints.MapGet("/ws", async context => + { + if (!context.WebSockets.IsWebSocketRequest) + { + context.Response.StatusCode = 400; + } + else + { + using var ws = await context.WebSockets.AcceptWebSocketAsync(); + await Echo(ws); + } + }); + + async Task Echo(WebSocket webSocket) + { + var buffer = new byte[1024 * 4]; + var result = await webSocket.ReceiveAsync(buffer.AsMemory(), default); + while (result.MessageType != WebSocketMessageType.Close) + { + await webSocket.SendAsync(buffer.AsMemory(..result.Count), result.MessageType, result.EndOfMessage, default); + result = await webSocket.ReceiveAsync(buffer.AsMemory(), default); + } + await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", default); + } }); } }