// tesses.webserver.websocket/Tesses.WebServer.WebSocket/Class1.cs
// (listing metadata: 465 lines, 16 KiB, C#)

using System;
using Tesses.WebServer;
using System.Threading.Tasks;
using Tesses.WebServer.WebSocket;
using System.Threading;
using System.Text;
using Newtonsoft.Json;
using System.Security.Cryptography;
using System.Collections.Generic;
using System.Linq;
using System.Collections;
using System.IO;
namespace Tesses.WebServer
{
/// <summary>
/// Base class for a WebSocket endpoint. Subclasses react to <see cref="OnConnectionStarted"/> and
/// <see cref="OnReceiveMessage"/>, send with <see cref="SendMessage"/>, and may switch on a keep-alive
/// ping timer (every 10 seconds) through <see cref="Enabled"/>.
/// </summary>
public abstract class EasyWebSocketServer
{
    // Keep-alive timer; created in Opened and disposed in Close.
    System.Timers.Timer timer;
    // Transport callbacks supplied by the connection in Opened.
    Func<WebSocketMessage,Task> wsm;
    Func<byte[],Task> ping;
    // True between Opened creating the timer and Close disposing it.
    bool canEnable;

    /// <summary>
    /// Whether the keep-alive ping timer is running. Reads false, and writes are ignored, before the
    /// connection has been opened and after it has been closed.
    /// </summary>
    public bool Enabled {
        get{
            if(!canEnable) return false;
            return timer.Enabled;
        }
        set{
            if(canEnable) timer.Enabled = value;
        }}

    /// <summary>
    /// Called by the connection once it starts: stores the send/ping callbacks, prepares the (initially
    /// stopped) keep-alive timer and then runs <see cref="OnConnectionStarted"/>.
    /// </summary>
    internal async Task Opened(Func<WebSocketMessage,Task> sendWsm,Func<byte[],Task> ping,CancellationToken token)
    {
        this.wsm=sendWsm;
        this.ping =ping;
        timer=new System.Timers.Timer();
        timer.Elapsed += async(sender,e)=>{
            try{
                await Ping();
            }catch(Exception ex)
            {
                // A failed keep-alive ping is deliberately ignored; the receive loop detects dead connections.
                _=ex;
            }
        };
        timer.Interval = 10000;
        canEnable=true;
        await OnConnectionStarted(token);
    }

    /// <summary>Invoked when the connection opens; <paramref name="token"/> is cancelled when it ends.</summary>
    public abstract Task OnConnectionStarted(CancellationToken token);

    /// <summary>Sends a ping frame whose payload is the ASCII text "Ping".</summary>
    public async Task Ping()
    {
        await Ping(new byte[]{ (byte)'P', (byte)'i', (byte)'n', (byte)'g' });
    }

    /// <summary>Sends a ping frame carrying <paramref name="data"/>.</summary>
    public async Task Ping(byte[] data)
    {
        await ping(data);
    }

    /// <summary>Invoked for every complete message received from the peer.</summary>
    public abstract Task OnReceiveMessage(WebSocketMessage msg);

    /// <summary>Sends <paramref name="msg"/> to the peer.</summary>
    public async Task SendMessage(WebSocketMessage msg)
    {
        await wsm(msg);
    }

    /// <summary>
    /// Called when the connection ends: stops and disposes the keep-alive timer, then calls
    /// <see cref="OnConnectionEnded"/>.
    /// </summary>
    /// <param name="clean">True if the peer closed the connection with a Close frame.</param>
    public void Close(bool clean)
    {
        canEnable=false;
        // Opened runs on a different thread than the receive loop, so the connection can end before
        // the timer has been created; guard against a null timer instead of throwing.
        var keepAlive = timer;
        if(keepAlive != null)
        {
            keepAlive.Enabled=false;
            keepAlive.Dispose();
        }
        OnConnectionEnded(clean);
    }

    /// <summary>Hook for subclasses; called once the connection has ended.</summary>
    protected virtual void OnConnectionEnded(bool clean)
    {
    }
}
/// <summary>
/// Extension methods that run a WebSocket session over a <see cref="ServerContext"/>, plus small
/// header-dictionary helpers used by the handshake.
/// </summary>
public static class WebSocketExtensions
{
    // True when the first value stored under key `t` equals `t2`. Uses the default equality
    // comparer so null values compare safely instead of throwing.
    internal static bool FirstEquals<T1,T2>(this Dictionary<T1,List<T2>> dict,T1 t,T2 t2)
    {
        T2 firstVal;
        return dict.TryGetFirst(t,out firstVal) && EqualityComparer<T2>.Default.Equals(firstVal,t2);
    }

    // True when any value stored under key `t` equals `t2`. List<T>.Contains uses the default
    // equality comparer, so null entries no longer cause a NullReferenceException.
    internal static bool AnyEquals<T1,T2>(this Dictionary<T1,List<T2>> dict,T1 t,T2 t2)
    {
        List<T2> items;
        return dict.TryGetValue(t,out items) && items != null && items.Contains(t2);
    }

    /// <summary>
    /// Blocking variant of <see cref="StartWebSocketConnectionAsync"/> for synchronous callbacks.
    /// <paramref name="opened"/> receives blocking send and ping delegates and a token that is cancelled
    /// when the connection ends. This method blocks until the connection has finished.
    /// </summary>
    public static void StartWebSocketConnection(this ServerContext ctx,Action<Action<WebSocketMessage>,Action<byte[]>,CancellationToken> opened,Action<WebSocketMessage> arrived,Action<bool> closed)
    {
        var t=ctx.StartWebSocketConnectionAsync(async(s,p,c)=>await Task.Run(()=>opened(
        (mm)=>{
            Task.Run(async()=>await s(mm)).Wait();
        },(data)=>{
            Task.Run(async()=>await p(data)).Wait();
        },c)),async(m)=>await Task.Run(()=>arrived(m)),closed);
        Task.Run(()=>t).Wait();
    }

    /// <summary>Runs a WebSocket session whose callbacks are provided by <paramref name="wss"/>.</summary>
    public static async Task StartEasyWebSocketConnectionAsync(this ServerContext ctx,EasyWebSocketServer wss)
    {
        await ctx.StartWebSocketConnectionAsync(wss.Opened,wss.OnReceiveMessage,wss.Close);
    }

    /// <summary>
    /// Upgrades <paramref name="ctx"/> to a WebSocket and runs the session until it ends.
    /// </summary>
    /// <param name="opened">
    /// Started on a separate thread alongside the receive loop. It gets the send and ping functions and a
    /// token that is cancelled when the receive loop finishes, whether it finishes normally or by throwing.
    /// </param>
    /// <param name="arrived">Invoked per received message; exceptions it throws are ignored.</param>
    /// <param name="closed">Invoked when the receive loop ends; the argument is true for a clean close.</param>
    public static async Task StartWebSocketConnectionAsync(this ServerContext ctx,Func<Func<WebSocketMessage,Task>,Func<byte[],Task>,CancellationToken,Task> opened,Func<WebSocketMessage,Task> arrived,Action<bool> closed)
    {
        WebSocketServer server=new WebSocketServer(ctx);
        server.MessageArrived+=async(sender,e)=>{
            try{
                await arrived(e.Message);
            }catch(Exception ex)
            {
                // Handler failures must not tear down the receive loop; they are deliberately ignored.
                _=ex;
            }
        };
        server.WebSocketClosed+=(sender,e)=>{
            closed(e.Clean);
        };
        using(var cts=new CancellationTokenSource()){
            // ThreadStart returns void, so this async lambda is async-void: the thread ends at the
            // lambda's first incomplete await, while `opened` itself keeps running.
            Thread t=new Thread(async()=>{
                try{
                    await opened(server.SendMessageAsync,server.Ping,cts.Token);
                }catch(Exception ex)
                {
                    _=ex;
                }
            });
            t.Start();
            try
            {
                await server.StartAsync();
            }
            finally
            {
                // Cancel even when StartAsync throws, so `opened` never keeps running with a live token.
                cts.Cancel();
                t.Join();
            }
        }
    }
}
}
namespace Tesses.WebServer.WebSocket
{
/// <summary>
/// One WebSocket message. The payload is stored as raw bytes; for text messages those bytes are the
/// UTF-8 encoding of <see cref="Text"/>, and <see cref="Binary"/> is false.
/// </summary>
public class WebSocketMessage
{
    // Raw payload bytes (UTF-8 text for text messages).
    private byte[] payload;

    /// <summary>Creates a text message from <paramref name="text"/>.</summary>
    public static WebSocketMessage Create(string text) => new WebSocketMessage { Text = text };

    /// <summary>Creates a binary message carrying <paramref name="data"/>.</summary>
    public static WebSocketMessage Create(byte[] data) => new WebSocketMessage { Data = data };

    /// <summary>Creates a text message containing <paramref name="data"/> serialized as JSON.</summary>
    public static WebSocketMessage Create(object data)
    {
        var result = new WebSocketMessage();
        result.EncodeJson(data);
        return result;
    }

    // Starts out as an empty binary message; the Create overloads overwrite the payload.
    private WebSocketMessage() => Data = new byte[0];

    // Used by the receiver, which already knows the payload kind.
    internal WebSocketMessage(byte[] message, bool binary)
    {
        payload = message;
        Binary = binary;
    }

    /// <summary>True for a binary message, false for a text message.</summary>
    public bool Binary { get; private set; }

    /// <summary>The raw payload bytes. Assigning it marks the message as binary.</summary>
    public byte[] Data
    {
        get => payload;
        private set
        {
            payload = value;
            Binary = true;
        }
    }

    /// <summary>
    /// The payload decoded as UTF-8. Assigning it stores the UTF-8 bytes and marks the message as text.
    /// </summary>
    public string Text
    {
        get => Encoding.UTF8.GetString(Data);
        private set
        {
            payload = Encoding.UTF8.GetBytes(value);
            Binary = false;
        }
    }

    /// <summary>Deserializes <see cref="Text"/> from JSON into a <typeparamref name="T"/>.</summary>
    public T DecodeJson<T>() => JsonConvert.DeserializeObject<T>(Text);

    // Serializes `data` to JSON and stores it as the text payload.
    private void EncodeJson(object data) => Text = JsonConvert.SerializeObject(data);
}
/// <summary>Event data carrying one complete message received from the peer.</summary>
public class WebSocketMessageEventArgs : EventArgs
{
    /// <param name="message">The message that arrived.</param>
    public WebSocketMessageEventArgs(WebSocketMessage message) => this.Message = message;

    /// <summary>The message that arrived.</summary>
    public WebSocketMessage Message { get; private set; }
}
/// <summary>Event data raised when a WebSocket connection's receive loop ends.</summary>
public class WebSocketClosedEventArgs : EventArgs
{
    /// <param name="clean">Whether the connection ended cleanly.</param>
    public WebSocketClosedEventArgs(bool clean) => this.Clean = clean;

    /// <summary>Whether the connection ended cleanly.</summary>
    public bool Clean { get; private set; }
}
/// <summary>
/// Server side of one WebSocket connection (RFC 6455, protocol version 13) over a <see cref="ServerContext"/>.
/// <see cref="StartAsync"/> performs the opening handshake and then reads frames until the connection ends.
/// The send methods may be called from other tasks; writes are serialized so frames never interleave.
/// </summary>
public class WebSocketServer
{
    // Largest payload put into a single outgoing data frame; longer messages are sent as fragments.
    private const int MaxFragmentSize = 4096;
    // Control frames (close/ping/pong) may carry at most 125 payload bytes (RFC 6455 §5.5).
    private const int MaxControlPayload = 125;
    // Fixed GUID from RFC 6455 §1.3, appended to the client key to derive Sec-WebSocket-Accept.
    private const string HandshakeGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Write lock: only one frame (or one fragmented message) is written to the stream at a time.
    SemaphoreSlim semaphoreSlim = new SemaphoreSlim(1, 1);
    // Completed with true once the 101 response has been flushed, or with false if the handshake is
    // rejected or fails. Senders await this instead of busy-waiting on a flag.
    readonly TaskCompletionSource<bool> handshakeCompleted =
        new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
    ServerContext context;

    /// <summary>Creates a WebSocket endpoint for a request that has not been upgraded yet.</summary>
    /// <param name="ctx">The request context whose connection will be upgraded.</param>
    public WebSocketServer(ServerContext ctx)
    {
        context = ctx;
    }

    /// <summary>Raised for each complete text or binary message, after its fragments are reassembled.</summary>
    public EventHandler<WebSocketMessageEventArgs> MessageArrived;

    /// <summary>
    /// Raised when the receive loop ends. <c>Clean</c> is true if the peer sent a Close frame, and false if
    /// the stream ended or broke first. Not raised when the handshake is rejected.
    /// </summary>
    public EventHandler<WebSocketClosedEventArgs> WebSocketClosed;

    // Encodes a payload length in the 7-bit, 126+16-bit or 127+64-bit big-endian form (RFC 6455 §5.2).
    // Server frames are never masked, so the mask bit is always 0.
    private byte[] glenBytes(long len)
    {
        if (len < 126)
        {
            return new byte[] { (byte)len };
        }
        else if (len <= ushort.MaxValue)
        {
            byte[] num = BitConverter.GetBytes((ushort)len);
            if (BitConverter.IsLittleEndian)
            {
                Array.Reverse(num);
            }
            return new byte[] { 126, num[0], num[1] };
        }
        else
        {
            byte[] num = BitConverter.GetBytes(len);
            if (BitConverter.IsLittleEndian)
            {
                Array.Reverse(num);
            }
            return new byte[] { 127, num[0], num[1], num[2], num[3], num[4], num[5], num[6], num[7] };
        }
    }

    /// <summary>
    /// Sends <paramref name="msg"/> as a text or binary message. Payloads longer than 4096 bytes are split
    /// into continuation fragments. An empty message is sent as a single empty final frame.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="msg"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The opening handshake did not succeed.</exception>
    public async Task SendMessageAsync(WebSocketMessage msg)
    {
        if (msg == null)
        {
            throw new ArgumentNullException(nameof(msg));
        }
        await WaitForHandshakeAsync();
        byte[] data = msg.Data;
        int opcode = msg.Binary ? 0x2 : 0x1;
        await semaphoreSlim.WaitAsync();
        try
        {
            // do/while so an empty payload still produces exactly one (empty, final) frame, and so a
            // payload that is an exact multiple of the fragment size is not truncated.
            int offset = 0;
            do
            {
                int count = Math.Min(MaxFragmentSize, data.Length - offset);
                bool fin = offset + count == data.Length;
                // Only the first fragment carries the real opcode; the rest are continuations (0x0).
                await WriteFrameAsync(fin, offset == 0 ? opcode : 0x0, data, offset, count);
                offset += count;
            }
            while (offset < data.Length);
            await context.NetworkStream.FlushAsync();
        }
        finally
        {
            semaphoreSlim.Release();
        }
    }

    /// <summary>Sends a ping control frame carrying <paramref name="ping"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ping"/> is null.</exception>
    /// <exception cref="ArgumentException">The payload exceeds 125 bytes.</exception>
    /// <exception cref="InvalidOperationException">The opening handshake did not succeed.</exception>
    public async Task Ping(byte[] ping)
    {
        if (ping == null)
        {
            throw new ArgumentNullException(nameof(ping));
        }
        if (ping.Length > MaxControlPayload)
        {
            throw new ArgumentException("A ping payload may be at most 125 bytes.", nameof(ping));
        }
        await WaitForHandshakeAsync();
        await SendSingleFrameAsync(0x9, ping, ping.Length);
    }

    // Answers a ping by echoing its payload in a pong frame.
    private async Task PongSend(byte[] payload)
    {
        await SendSingleFrameAsync(0xA, payload, payload.Length);
    }

    // Answers the peer's Close frame with a Close of our own (RFC 6455 §5.5.1), echoing the status code
    // (first two payload bytes) when one was sent. This is best effort: the peer may already be gone.
    private async Task SendCloseReplyAsync(byte[] received)
    {
        try
        {
            await SendSingleFrameAsync(0x8, received, received.Length >= 2 ? 2 : 0);
        }
        catch (Exception ex) when (ex is IOException || ex is ObjectDisposedException)
        {
            // The connection is ending anyway; nothing else to do.
        }
    }

    // Writes one final (FIN) frame from the first `count` bytes of `payload` under the write lock, then flushes.
    private async Task SendSingleFrameAsync(int opcode, byte[] payload, int count)
    {
        await semaphoreSlim.WaitAsync();
        try
        {
            await WriteFrameAsync(true, opcode, payload, 0, count);
            await context.NetworkStream.FlushAsync();
        }
        finally
        {
            semaphoreSlim.Release();
        }
    }

    // Builds one unmasked frame (FIN bit and opcode, length, payload slice) and writes it.
    // The caller must hold semaphoreSlim.
    private async Task WriteFrameAsync(bool fin, int opcode, byte[] payload, int offset, int count)
    {
        byte[] lengthField = glenBytes(count);
        byte[] frame = new byte[1 + lengthField.Length + count];
        frame[0] = (byte)((fin ? 0x80 : 0x00) | (opcode & 0x0F));
        Array.Copy(lengthField, 0, frame, 1, lengthField.Length);
        Array.Copy(payload, offset, frame, 1 + lengthField.Length, count);
        await context.NetworkStream.WriteAsync(frame, 0, frame.Length);
    }

    // Waits until StartAsync has finished the opening handshake. Throws if the upgrade did not happen.
    private async Task WaitForHandshakeAsync()
    {
        if (!await handshakeCompleted.Task)
        {
            throw new InvalidOperationException("The WebSocket handshake did not complete, so no frames can be sent.");
        }
    }

    // Computes Sec-WebSocket-Accept: base64(SHA-1(key + GUID)). Returns "" for a blank key.
    private string get_Sec_WebSocketAccept(string headerVal)
    {
        if (string.IsNullOrWhiteSpace(headerVal))
        {
            return "";
        }
        using (SHA1 sha1 = SHA1.Create())
        {
            byte[] hash = sha1.ComputeHash(Encoding.UTF8.GetBytes(headerVal.Trim() + HandshakeGuid));
            return Convert.ToBase64String(hash);
        }
    }

    // XORs `message` in place with the 4-byte masking key. Does nothing if the key is shorter than 4 bytes.
    private void MaskMessage(byte[] key, byte[] message)
    {
        if (key.Length < 4)
        {
            return;
        }
        for (int i = 0; i < message.Length; i++)
        {
            message[i] ^= key[i % 4];
        }
    }

    // Fills `buffer` completely. A single ReadAsync may return fewer bytes than requested.
    // Returns false if the stream ends first.
    private async Task<bool> ReadExactlyAsync(byte[] buffer)
    {
        int filled = 0;
        while (filled < buffer.Length)
        {
            int read = await context.NetworkStream.ReadAsync(buffer, filled, buffer.Length - filled);
            if (read <= 0)
            {
                return false;
            }
            filled += read;
        }
        return true;
    }

    // Reads the remainder of a frame whose second header byte is `len` (mask bit plus 7-bit length).
    // That remainder is the extended length, the masking key if present, and the payload, which is
    // unmasked in place. Returns null if the stream ends early or the length cannot fit in a byte array.
    private async Task<byte[]> read_packet_async(byte len)
    {
        bool masked = (len & 0x80) != 0;
        long payloadLength = len & 0x7F;
        if (payloadLength == 126)
        {
            byte[] extended = new byte[2];
            if (!await ReadExactlyAsync(extended))
            {
                return null;
            }
            // Unsigned 16-bit big-endian. Decoding via a signed short would make 32768..65535 negative.
            payloadLength = (extended[0] << 8) | extended[1];
        }
        else if (payloadLength == 127)
        {
            byte[] extended = new byte[8];
            if (!await ReadExactlyAsync(extended))
            {
                return null;
            }
            payloadLength = 0;
            foreach (byte part in extended)
            {
                payloadLength = (payloadLength << 8) | part;
            }
            // The top bit must be 0 (RFC 6455 §5.2), and a byte[] cannot exceed int.MaxValue elements.
            if (payloadLength < 0 || payloadLength > int.MaxValue)
            {
                return null;
            }
        }
        // NOTE(review): client-declared lengths are not capped, so a peer can force large allocations.
        // Consider a configurable maximum frame size.
        byte[] maskingKey = new byte[4];
        if (masked && !await ReadExactlyAsync(maskingKey))
        {
            return null;
        }
        byte[] payload = new byte[payloadLength];
        if (!await ReadExactlyAsync(payload))
        {
            return null;
        }
        if (masked)
        {
            MaskMessage(maskingKey, payload);
        }
        return payload;
    }

    /// <summary>
    /// Validates the upgrade request and answers with 101 Switching Protocols. Then reads frames until the
    /// peer closes, the stream ends, or the connection breaks, and finally raises <see cref="WebSocketClosed"/>.
    /// Returns immediately, without raising any event, if the request is not a valid version-13 WebSocket
    /// upgrade. In that case pending and future sends fail with <see cref="InvalidOperationException"/>.
    /// </summary>
    public async Task StartAsync()
    {
        bool accepted = false;
        try
        {
            accepted = await AcceptHandshakeAsync();
        }
        finally
        {
            // Wake senders blocked in WaitForHandshakeAsync whatever the outcome, so they never wait forever.
            handshakeCompleted.TrySetResult(accepted);
        }
        if (!accepted)
        {
            return;
        }

        bool clean = false;
        try
        {
            clean = await ReceiveFramesAsync();
        }
        catch (Exception ex) when (ex is IOException || ex is ObjectDisposedException)
        {
            // The peer dropped the connection mid-frame; fall through and report an unclean close.
        }
        WebSocketClosed?.Invoke(this, new WebSocketClosedEventArgs(clean));
    }

    // Checks Sec-WebSocket-Key, Upgrade and Sec-WebSocket-Version, then writes and flushes the 101 response.
    // Returns false (writing nothing) if the request is not an acceptable WebSocket upgrade.
    private async Task<bool> AcceptHandshakeAsync()
    {
        string clientKey;
        if (!context.RequestHeaders.TryGetFirst("Sec-WebSocket-Key", out clientKey))
        {
            return false;
        }
        string accept = get_Sec_WebSocketAccept(clientKey);
        if (accept.Length == 0)
        {
            // A blank key cannot produce a valid accept value; the client would reject the response anyway.
            return false;
        }
        if (!HasHeaderValue("Upgrade", "websocket"))
        {
            return false;
        }
        if (!HasHeaderValue("Sec-WebSocket-Version", "13"))
        {
            return false;
        }

        context.StatusCode = 101;
        context.ResponseHeaders.Add("Upgrade", "websocket");
        if (context.ResponseHeaders.ContainsKey("Connection"))
        {
            context.ResponseHeaders["Connection"].Clear();
        }
        context.ResponseHeaders.Add("Connection", "Upgrade");
        context.ResponseHeaders.Add("Sec-WebSocket-Accept", accept);
        await context.WriteHeadersAsync();
        await context.NetworkStream.FlushAsync();
        return true;
    }

    // True when request header `name` has a value equal to `expected`, ignoring ASCII case
    // (RFC 6455 §4.2.1 treats the Upgrade value "websocket" case-insensitively).
    private bool HasHeaderValue(string name, string expected)
    {
        List<string> values;
        if (!context.RequestHeaders.TryGetValue(name, out values) || values == null)
        {
            return false;
        }
        foreach (string value in values)
        {
            if (value != null && string.Equals(value.Trim(), expected, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }
        }
        return false;
    }

    // Reads frames while the context is connected. It reassembles fragmented messages, answers pings,
    // and raises MessageArrived for each complete message.
    // Returns true when a Close frame was received, false when the stream ended or sent an unusable frame.
    private async Task<bool> ReceiveFramesAsync()
    {
        byte[] header = new byte[2];
        // Kept across frames so continuation frames can extend the message started by a text/binary frame.
        bool hasMessage = false;
        bool isBinary = false;
        using (MemoryStream assembled = new MemoryStream())
        {
            while (context.Connected)
            {
                if (!await ReadExactlyAsync(header))
                {
                    return false;
                }
                bool fin = (header[0] & 0x80) != 0;
                int opcode = header[0] & 0x0F;
                // Always consume the payload, even for frames that are ignored, so the stream stays
                // aligned on frame boundaries.
                byte[] payload = await read_packet_async(header[1]);
                if (payload == null)
                {
                    return false;
                }

                switch (opcode)
                {
                    case 0x0: // continuation; ignored if no message is in progress
                        if (hasMessage)
                        {
                            assembled.Write(payload, 0, payload.Length);
                        }
                        break;
                    case 0x1: // text
                    case 0x2: // binary
                        hasMessage = true;
                        isBinary = opcode == 0x2;
                        assembled.SetLength(0);
                        assembled.Write(payload, 0, payload.Length);
                        break;
                    case 0x8: // close
                        await SendCloseReplyAsync(payload);
                        return true;
                    case 0x9: // ping
                        await PongSend(payload);
                        break;
                    default: // pong and reserved opcodes: payload already consumed, nothing else to do
                        break;
                }

                // Control frames may arrive between fragments and always have FIN set, so only a
                // data frame (opcode 0x0-0x2) may complete the message being assembled.
                if (opcode <= 0x2 && fin && hasMessage)
                {
                    hasMessage = false;
                    WebSocketMessage msg = new WebSocketMessage(assembled.ToArray(), isBinary);
                    MessageArrived?.Invoke(this, new WebSocketMessageEventArgs(msg));
                }
            }
        }
        return false;
    }
}
}