| 1 | // websocket module implements websocket client and a websocket server |
| 2 | // attribution: @thecoderr the author of original websocket client |
| 3 | @[manualfree] |
| 4 | module websocket |
| 5 | |
| 6 | import net |
| 7 | import net.conv |
| 8 | import net.http |
| 9 | import net.ssl |
| 10 | import net.urllib |
| 11 | import time |
| 12 | import log |
| 13 | import rand |
| 14 | |
// empty_bytearr is a shared zero-length byte array, used e.g. as the "no masking
// key" value for frames sent by the server side.
const empty_bytearr = []u8{len: 0}
| 16 | |
// ClientState wraps the connection state so that it can be stored as a `shared`
// field of Client and accessed under lock/rlock from several threads.
pub struct ClientState {
pub mut:
	state State = .closed // current state of connection; a fresh value starts out .closed
}
| 21 | |
// Client represents one end of a websocket connection. new_client creates the
// client side (is_server false). NOTE(review): is_server is presumably set to
// true by the server part of this module for accepted connections; confirm there.
pub struct Client {
	is_server bool // when false, outgoing frames are masked; also selects the 'client->'/'server->' debug prefix
mut:
	ssl_conn          &ssl.SSLConn = unsafe { nil } // secure connection used when wss is used; recreated by reset_state
	proxy_url         string                // optional proxy URL, copied from ClientOpt.proxy_url
	flags             []Flag                // flags used in handshake; cleared by reset_state
	fragments         []Fragment            // current fragments of a message being reassembled; cleared on close/reset
	message_callbacks []MessageEventHandler // all callbacks on_message
	error_callbacks   []ErrorEventHandler   // all callbacks on_error
	open_callbacks    []OpenEventHandler    // all callbacks on_open
	close_callbacks   []CloseEventHandler   // all callbacks on_close
pub:
	is_ssl        bool   // true if secure socket is used (address starts with 'wss')
	uri           Uri    // uri of current connection
	id            string // unique id of client (random UUID v4 for clients created by new_client)
	read_timeout  i64    // copied from ClientOpt.read_timeout
	write_timeout i64    // copied from ClientOpt.write_timeout
pub mut:
	header            http.Header // headers that will be passed when connecting
	conn              &net.TcpConn = unsafe { nil } // underlying TCP socket connection; nil until connect() succeeds in dialing
	nonce_size        int = 16 // size of nonce used for masking
	panic_on_callback bool // set to true if errors while handling callbacks should panic instead of only being logged
	client_state      shared ClientState // current state of connection; read/written only under rlock/lock
	// logger used to log messages
	logger        &log.Logger = default_logger
	resource_name string // name of current resource
	last_pong_ut  i64    // unix time (seconds) at which the last pong message was received
}
| 51 | |
// Flag represents different types of headers in websocket handshake.
// NOTE(review): presumably each value records that the matching handshake header
// was received and validated; confirm in the handshake code (not shown here).
enum Flag {
	has_accept // NOTE(review): original comment was truncated ("Webs"); likely the Sec-WebSocket-Accept header, confirm
	has_connection // presumably the Connection header, confirm
	has_upgrade // presumably the Upgrade header, confirm
}
| 58 | |
// State represents the state of the websocket connection.
// connect() moves .closed -> .connecting -> .open; close() moves .open -> .closing,
// after which reset_state() puts the client back into .closed.
pub enum State {
	connecting = 0 // dialing / handshaking in progress
	open // handshake done, frames can be sent and received
	closing // a close frame is being sent
	closed // not connected (initial state of a new client)
}
| 66 | |
// Message represents a whole message combined from 1 to n frames.
// Its payload buffer is released with Message.free().
pub struct Message {
pub:
	opcode  OPCode // websocket frame type of this message
	payload []u8   // payload of the message
}
| 73 | |
// OPCode represents the supported websocket frame types.
// The numeric values are the opcodes written into the low 4 bits of the first
// frame byte (write_ptr and send_control_frame OR them with the 0x80 FIN bit).
pub enum OPCode {
	continuation = 0x00 // continues a fragmented message
	text_frame   = 0x01
	binary_frame = 0x02
	close        = 0x08 // control frame
	ping         = 0x09 // control frame
	pong         = 0x0A // control frame
}
| 83 | |
// ClientOpt holds the optional settings accepted by new_client.
// Being a @[params] struct, it can be passed as trailing named arguments.
@[params]
pub struct ClientOpt {
pub:
	read_timeout  i64 = net.infinite_timeout // default: no read timeout
	write_timeout i64 = 30 * time.second     // default: 30 seconds
	logger        &log.Logger = default_logger
	proxy_url     string // optional proxy URL used to open the websocket TCP tunnel
}
| 92 | |
// new_client creates (but does not connect) a websocket client for `address`.
// The address is parsed immediately, so a malformed URL is reported here, and
// an ssl connection object is allocated up front. The client starts in the
// .closed state with a random UUID v4 id and an empty header set. Timeouts,
// logger and proxy URL come from `opt`. Call connect() to open the connection.
pub fn new_client(address string, opt ClientOpt) !&Client {
	parsed := parse_uri(address)!
	secure_conn := ssl.new_ssl_conn()!
	uses_tls := address.starts_with('wss')
	initial_state := ClientState{
		state: .closed
	}
	return &Client{
		is_server:     false
		conn:          unsafe { nil }
		ssl_conn:      secure_conn
		is_ssl:        uses_tls
		uri:           parsed
		logger:        opt.logger
		proxy_url:     opt.proxy_url
		read_timeout:  opt.read_timeout
		write_timeout: opt.write_timeout
		client_state:  initial_state
		id:            rand.uuid_v4()
		header:        http.new_header()
	}
}
| 113 | |
// connect dials the remote websocket server and performs the opening handshake.
// On success it moves the client to .open and fires the open callbacks.
// It returns an error if the client is already connecting, open or closing.
// If dialing or the handshake fails, the client is returned to the .closed state
// (mirroring the cleanup done by close()) and the error is propagated, so that
// connect() can be retried. Previously a failure left the client stuck in
// .connecting, and every later connect() call was rejected.
pub fn (mut ws Client) connect() ! {
	ws.assert_not_connected()!
	ws.set_state(.connecting)
	ws.logger.info('connecting to host ${ws.uri}')
	ws.conn = ws.dial_socket() or {
		// nothing was dialed successfully; just restore a reusable closed state
		ws.reset_state() or {}
		ws.set_state(.closed)
		return err
	}
	ws.handshake() or {
		// release the half-open socket before reporting the failure
		ws.shutdown_socket() or {}
		ws.reset_state() or {}
		ws.set_state(.closed)
		return err
	}
	ws.set_state(.open)
	ws.logger.info('successfully connected to host ${ws.uri}')
	ws.send_open_event()
}
| 125 | |
// listen runs the receive loop: while the state is .open it reads the next
// message and dispatches it by opcode.
// - text/binary: forwarded to the message callbacks, then freed
// - ping: answered with a pong that echoes the ping payload
// - pong: last_pong_ut is updated and the message is forwarded to the callbacks
// - close: the payload is validated (not 1 byte, code not in invalid_close_codes,
//   UTF-8 reason). If the connection is not already closing/closed, the close is
//   echoed back with close(). Then listen returns.
// - continuation: there is no fragmented message to continue here, so the
//   connection is closed with code 1002 and an error is returned
// Read errors with code net.error_eintr are retried. Other read errors are logged
// and passed to send_error_event, then returned, unless the connection is already
// closing/closed, in which case listen returns without error.
// When listen exits while the state is still .open, the connection is closed with 1000.
pub fn (mut ws Client) listen() ! {
	mut log_msg := 'Starting client listener, server(${ws.is_server})...'
	ws.logger.info(log_msg)
	// the module is @[manualfree], so temporary strings are released explicitly
	unsafe { log_msg.free() }
	defer {
		ws.logger.info('Quit client listener, server(${ws.is_server})...')
		if ws.get_state() == .open {
			ws.close(1000, 'closed by client') or {}
		}
	}
	for ws.get_state() == .open {
		msg := ws.read_next_message() or {
			if err.code() == net.error_eintr { // Check for EINTR and retry
				continue
			} else if ws.get_state() in [.closed, .closing] {
				return
			} else {
				ws.debug_log('failed to read next message: ${err}')
				ws.send_error_event('failed to read next message: ${err}')
				return err
			}
		}
		// the connection may have been closed while the read was in progress
		if ws.get_state() in [.closed, .closing] {
			return
		}
		ws.debug_log('got message: ${msg.opcode}')
		match msg.opcode {
			.text_frame {
				ws.debug_log('read: text')
				ws.send_message_event(msg)
				unsafe { msg.free() }
			}
			.binary_frame {
				ws.debug_log('read: binary')
				ws.send_message_event(msg)
				unsafe { msg.free() }
			}
			.ping {
				ws.debug_log('read: ping, sending pong')
				// the pong echoes the payload received with the ping
				ws.send_control_frame(.pong, 'PONG', msg.payload) or {
					ws.logger.error('error in message callback sending PONG: ${err}')
					ws.send_error_event('error in message callback sending PONG: ${err}')
					if ws.panic_on_callback {
						panic(err)
					}
					// NOTE(review): msg is not freed on this path
					continue
				}
				if msg.payload.len > 0 {
					unsafe { msg.free() }
				}
			}
			.pong {
				ws.debug_log('read: pong')
				ws.last_pong_ut = time.now().unix()
				ws.send_message_event(msg)
				if msg.payload.len > 0 {
					unsafe { msg.free() }
				}
			}
			.close {
				ws.debug_log('read: close')
				// emits a close event (code 1000) when listen returns from this branch
				defer {
					ws.manage_clean_close()
				}
				if msg.payload.len > 0 {
					// a non-empty close payload must hold at least a 2-byte status code
					if msg.payload.len == 1 {
						ws.close(1002, 'close payload cannot be 1 byte')!
						return error('close payload cannot be 1 byte')
					}
					// status code is the first two payload bytes, most significant byte first
					code := u16(msg.payload[0]) << 8 | u16(msg.payload[1])
					if code in invalid_close_codes {
						ws.close(1002, 'invalid close code: ${code}')!
						return error('invalid close code: ${code}')
					}
					// the remaining bytes, if any, are the textual close reason
					reason := if msg.payload.len > 2 { msg.payload[2..] } else { []u8{} }
					if reason.len > 0 {
						ws.validate_utf_8(.close, reason)!
					}
					if ws.get_state() !in [.closing, .closed] {
						// sending close back according to spec
						ws.debug_log('close with reason, code: ${code}, reason: ${reason}')
						r := reason.bytestr()
						ws.close(code, r)!
					}
					unsafe { msg.free() }
				} else {
					if ws.get_state() !in [.closing, .closed] {
						ws.debug_log('close with reason, no code')
						// sending close back according to spec
						ws.close(1000, 'normal')!
					}
					unsafe { msg.free() }
				}
				return
			}
			.continuation {
				// a continuation reaching this point has no message to continue
				ws.logger.error('unexpected opcode continuation, nothing to continue')
				ws.send_error_event('unexpected opcode continuation, nothing to continue')
				ws.close(1002, 'nothing to continue')!
				return error('unexpected opcode continuation, nothing to continue')
			}
		}
	}
}
| 231 | |
// manage_clean_close reports a normal shutdown to the close callbacks, using
// status code 1000 and the reason 'closed by client'.
fn (mut ws Client) manage_clean_close() {
	normal_code := 1000
	ws.send_close_event(normal_code, 'closed by client')
}
| 236 | |
// ping sends a ping control frame with an empty payload to the peer.
// Errors from send_control_frame are propagated.
pub fn (mut ws Client) ping() ! {
	no_payload := []u8{}
	ws.send_control_frame(.ping, 'PING', no_payload)!
}
| 241 | |
// pong sends an unsolicited pong control frame with an empty payload to the peer.
// Errors from send_control_frame are propagated.
pub fn (mut ws Client) pong() ! {
	no_payload := []u8{}
	ws.send_control_frame(.pong, 'PONG', no_payload)!
}
| 246 | |
// write_ptr sends `payload_len` bytes starting at `bytes` as one complete frame
// (FIN bit set) of type `code`. It returns the byte count that socket_write
// reports for the whole frame (header + payload).
// Frames sent by a client are masked with a fresh masking key; frames sent by a
// server are not.
// An error is returned when:
// - the socket is not open
// - `payload_len` is negative
// - a client frame is too large to encode (the connection is then closed with 1009)
// - the socket write fails
// All temporary buffers are released with `defer`, so they are freed on the error
// paths as well as on success. Previously they leaked whenever the write or the
// close failed.
pub fn (mut ws Client) write_ptr(bytes &u8, payload_len int, code OPCode) !int {
	if ws.get_state() != .open || ws.conn.sock.handle < 1 {
		// todo: send error here later
		return error('trying to write on a closed socket!')
	}
	if payload_len < 0 {
		return error('invalid payload length: ${payload_len}')
	}
	// 2 base header bytes, plus an extended payload length field when needed
	mut header_len := 2
	if payload_len > 0xffff {
		header_len += 8 // 64-bit extended payload length
	} else if payload_len > 125 {
		header_len += 2 // 16-bit extended payload length
	}
	if !ws.is_server {
		header_len += 4 // masking key
	}
	// every header byte is written below; zero-initialise instead of ASCII '0'
	mut header := []u8{len: header_len}
	masking_key := create_masking_key()
	defer {
		unsafe {
			masking_key.free()
			header.free()
		}
	}
	header[0] = u8(int(code)) | 0x80
	if ws.is_server {
		if payload_len <= 125 {
			header[1] = u8(payload_len)
		} else if payload_len <= 0xffff {
			len16 := conv.hton16(u16(payload_len))
			header[1] = 126
			unsafe { vmemcpy(&header[2], &len16, 2) }
		} else if payload_len <= 0x7fffffff {
			len_bytes := htonl64(u64(payload_len))
			header[1] = 127
			unsafe { vmemcpy(&header[2], len_bytes.data, 8) }
		}
	} else {
		if payload_len <= 125 {
			header[1] = u8(payload_len | 0x80)
			header[2] = masking_key[0]
			header[3] = masking_key[1]
			header[4] = masking_key[2]
			header[5] = masking_key[3]
		} else if payload_len <= 0xffff {
			len16 := conv.hton16(u16(payload_len))
			header[1] = (126 | 0x80)
			unsafe { vmemcpy(&header[2], &len16, 2) }
			header[4] = masking_key[0]
			header[5] = masking_key[1]
			header[6] = masking_key[2]
			header[7] = masking_key[3]
		} else if payload_len <= 0x7fffffff {
			len64 := htonl64(u64(payload_len))
			header[1] = (127 | 0x80)
			unsafe { vmemcpy(&header[2], len64.data, 8) }
			header[10] = masking_key[0]
			header[11] = masking_key[1]
			header[12] = masking_key[2]
			header[13] = masking_key[3]
		} else {
			ws.close(1009, 'frame too large')!
			return error('frame too large')
		}
	}
	mut frame_buf := []u8{len: header_len + payload_len}
	defer {
		unsafe { frame_buf.free() }
	}
	unsafe {
		vmemcpy(&frame_buf[0], &u8(header.data), header.len)
		if payload_len > 0 {
			vmemcpy(&frame_buf[header.len], bytes, payload_len)
		}
	}
	if !ws.is_server {
		for i in 0 .. payload_len {
			frame_buf[header_len + i] ^= masking_key[i % 4] & 0xff
		}
	}
	written_len := ws.socket_write(frame_buf)!
	return written_len
}
| 323 | |
// write sends the contents of `bytes` as one websocket frame of type `code`
// and returns the byte count reported by write_ptr.
pub fn (mut ws Client) write(bytes []u8, code OPCode) !int {
	data_ptr := &u8(bytes.data)
	written := ws.write_ptr(data_ptr, bytes.len, code)!
	return written
}
| 328 | |
// write_string sends `str` as a single text frame (opcode .text_frame) and
// returns the byte count reported by write_ptr.
pub fn (mut ws Client) write_string(str string) !int {
	written := ws.write_ptr(str.str, str.len, OPCode.text_frame)!
	return written
}
| 333 | |
// close starts the closing handshake.
// It sends a close frame carrying `code` (when `code` > 0) and `message` as the
// reason. A deferred cleanup then shuts down the socket, resets the client state
// and fires the close callbacks.
// It returns an error if the connection is already closing/closed or has no
// valid socket handle, and propagates send errors.
// Fixes compared to the previous version:
// - the status code is written explicitly in network (big-endian) byte order; the
//   old conv.hton16 + manual byte extraction was only correct on little-endian hosts
// - a never-connected client (nil `conn`) no longer crashes in the debug log
// - control frame payloads are limited to 125 bytes (RFC 6455), so the reason is
//   truncated to 123 bytes, backing off to a UTF-8 character boundary
// - the frame buffer is also freed when sending fails
pub fn (mut ws Client) close(code int, message string) ! {
	ws.debug_log('sending close, ${code}, ${message}')
	ws_state := ws.get_state()
	handle := if isnil(ws.conn) { -1 } else { ws.conn.sock.handle }
	if ws_state in [.closed, .closing] || handle <= 1 {
		ws.debug_log('close: Websocket already closed (${ws_state}), ${message}, ${code} handle(${handle})')
		err_msg := 'Socket already closed: ${code}'
		return error(err_msg)
	}
	defer {
		ws.shutdown_socket() or {}
		ws.reset_state() or {}
		ws.send_close_event(code, message)
	}
	ws.set_state(.closing)
	if code > 0 {
		// 2 bytes of status code + reason must fit in a 125 byte control payload
		mut reason_len := message.len
		if reason_len > 123 {
			reason_len = 123
			// do not cut a multi-byte UTF-8 sequence in half
			for reason_len > 0 && (message[reason_len] & 0xc0) == 0x80 {
				reason_len--
			}
		}
		mut close_frame := []u8{len: reason_len + 2}
		close_frame[0] = u8((code >> 8) & 0xff)
		close_frame[1] = u8(code & 0xff)
		for i in 0 .. reason_len {
			close_frame[i + 2] = message[i]
		}
		ws.send_control_frame(.close, 'CLOSE', close_frame) or {
			unsafe { close_frame.free() }
			return err
		}
		unsafe { close_frame.free() }
	} else {
		ws.send_control_frame(.close, 'CLOSE', [])!
	}
	ws.fragments = []
}
| 367 | |
// send_control_frame sends one control frame (close, ping or pong) of type `code`
// with `payload`. `frame_typ` is used only in log and error messages.
// Frames sent by a client are masked with a fresh 4-byte masking key; frames sent
// by a server are written unmasked.
// Errors:
// - the state is not .open/.closing, `conn` is nil, or its socket handle is invalid
// - `payload` is longer than 125 bytes, the maximum control frame payload
//   (longer payloads used to overflow into the mask bit of the length byte)
// - the socket write fails
// The previous guard combined the checks with `&&` and `handle > 1`. That let
// writes through on an invalid handle and dereferenced a nil `conn` on clients
// that were never connected.
fn (mut ws Client) send_control_frame(code OPCode, frame_typ string, payload []u8) ! {
	ws.debug_log('send control frame ${code}, frame_type: ${frame_typ}')
	if ws.get_state() !in [.open, .closing] || isnil(ws.conn) || ws.conn.sock.handle < 1 {
		return error('socket is not connected')
	}
	if payload.len > 125 {
		return error('send_control_frame: ${frame_typ} payload too large (${payload.len} bytes, max 125)')
	}
	header_len := if ws.is_server { 2 } else { 6 }
	mut control_frame := []u8{len: header_len + payload.len}
	mut masking_key := if !ws.is_server { create_masking_key() } else { empty_bytearr }
	defer {
		unsafe {
			control_frame.free()
			if masking_key.len > 0 {
				masking_key.free()
			}
		}
	}
	// FIN bit set: control frames are never fragmented
	control_frame[0] = u8(int(code) | 0x80)
	if ws.is_server {
		control_frame[1] = u8(payload.len)
		if payload.len > 0 {
			unsafe { vmemcpy(&control_frame[header_len], &payload[0], payload.len) }
		}
	} else {
		// mask bit + payload length, followed by the 4-byte masking key
		control_frame[1] = u8(payload.len | 0x80)
		for i in 0 .. 4 {
			control_frame[2 + i] = masking_key[i]
		}
		for i, b in payload {
			control_frame[header_len + i] = b ^ masking_key[i % 4]
		}
	}
	ws.socket_write(control_frame) or {
		return error('send_control_frame: error sending ${frame_typ} control frame.')
	}
}
| 427 | |
// parse_uri parses a websocket URL into a Uri.
// When the URL has no explicit port, it defaults to 80 for ws:// and 443 for
// wss://; for other schemes the port stays empty.
// The request URI is split at the FIRST '?' into the resource path and the query
// string, which keeps its leading '?'. Splitting on every '?' used to drop any
// part of the query after a second '?', which is a legal query character.
fn parse_uri(url string) !&Uri {
	u := urllib.parse(url)!
	request_uri := u.request_uri()
	mut port := u.port()
	if port == '' {
		uri := u.str()
		if uri.starts_with('ws://') {
			port = '80'
		} else if uri.starts_with('wss://') {
			port = '443'
		}
	}
	mut resource := request_uri
	mut querystring := ''
	if qpos := request_uri.index('?') {
		resource = request_uri[..qpos]
		querystring = request_uri[qpos..]
	}
	return &Uri{
		url:         url
		hostname:    u.hostname()
		port:        port
		resource:    resource
		querystring: querystring
	}
}
| 453 | |
// set_state sets the current state of the websocket connection.
// The write happens under the client_state lock, so it does not race with
// get_state() readers on other threads.
pub fn (mut ws Client) set_state(state State) {
	lock ws.client_state {
		ws.client_state.state = state
	}
}
| 460 | |
// get_state returns the current state of the websocket connection, read under
// a shared (read) lock of client_state.
pub fn (ws &Client) get_state() State {
	current := rlock ws.client_state {
		ws.client_state.state
	}
	return current
}
| 467 | |
// assert_not_connected succeeds only when the client is in the .closed state.
// It returns an error if the client is already connecting, already open, or still
// closing. connect() uses it to refuse a second connection attempt.
fn (ws &Client) assert_not_connected() ! {
	state := ws.get_state()
	if state == .connecting {
		return error('connect: websocket is connecting')
	}
	if state == .open {
		return error('connect: websocket already open')
	}
	if state == .closing {
		return error('connect: reconnect on closing websocket not supported, please use new client')
	}
}
| 477 | |
// reset_state puts the client back into a not-connected condition while holding
// the client_state lock:
// - state .closed
// - a freshly allocated ssl connection
// - empty handshake flags and fragment buffer
// It returns an error if the new ssl connection cannot be created. By then the
// state has already been set to .closed.
// NOTE(review): the previous ssl_conn is replaced without being freed here;
// confirm whether shutdown_socket() releases it.
pub fn (mut ws Client) reset_state() ! {
	lock ws.client_state {
		ws.client_state.state = .closed
		ws.ssl_conn = ssl.new_ssl_conn()!
		ws.flags = []
		ws.fragments = []
	}
}
| 487 | |
// debug_log logs `text` at debug level, prefixed with 'server-> ' or 'client-> '
// depending on which side of the connection this Client represents.
fn (mut ws Client) debug_log(text string) {
	side := if ws.is_server { 'server' } else { 'client' }
	ws.logger.debug('${side}-> ${text}')
}
| 496 | |
// free releases the payload buffer of the message (manual memory management,
// the module is @[manualfree]). The message must not be used afterwards.
pub fn (m &Message) free() {
	unsafe { m.payload.free() }
}
| 501 | |
// free releases the memory owned by the client's internal arrays (handshake
// flags, pending fragments and every registered callback list) and its header.
// It does not close the connection or release the tcp/ssl connection objects.
pub fn (c &Client) free() {
	unsafe {
		c.header.free()
		c.open_callbacks.free()
		c.message_callbacks.free()
		c.error_callbacks.free()
		c.close_callbacks.free()
		c.fragments.free()
		c.flags.free()
	}
}
| 514 | |