#![deny(missing_docs)]
use std::{
collections::HashMap,
io,
time::Duration
};
use tokio::{
net::{ ToSocketAddrs, UdpSocket },
time::timeout as tokioTimeout,
};
/// Bit mask with the low nibble of every byte of a `u32` set (`0x0F0F0F0F`).
///
/// Session IDs generated in this file only ever set bits covered by this mask,
/// because each generated byte is a single 4-bit value.
pub const SESSION_ID_MASK: u32 = 0x0F0F0F0F;
/// Two-byte prefix written at the start of every request packet built in this file.
const REQUEST_HEADER: [u8; 2] = [0xFE, 0xFD];
/// Monotonic counter backing [`gen_session_id`].
///
/// Stored as an atomic so that IDs can be generated from several threads
/// without a data race and without two callers observing the same value.
static SESSION_ID_COUNTER: std::sync::atomic::AtomicU16 =
    std::sync::atomic::AtomicU16::new(0);

/// Produces the next session ID.
///
/// The 16-bit counter is advanced by one (wrapping on overflow) and its four
/// nibbles are spread over the four bytes of the returned `u32`, most
/// significant nibble first. Every byte of the result is therefore in the
/// range `0x00..=0x0F`.
fn gen_session_id() -> u32 {
    use std::sync::atomic::Ordering;

    // `fetch_add` returns the previous value; add one so the first ID is built
    // from a counter value of 1. The increment and the read are one atomic step.
    let counter = SESSION_ID_COUNTER
        .fetch_add(1, Ordering::Relaxed)
        .wrapping_add(1);

    let [high, low] = counter.to_be_bytes();
    u32::from_be_bytes([high >> 4, high & 0x0F, low >> 4, low & 0x0F])
}
/// Error signalling that a receive buffer was filled completely, so the
/// received data may have been cut short.
///
/// The wrapped value is a byte count supplied by whoever constructs the error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MaybeIncompleteDataError(usize);

impl core::fmt::Display for MaybeIncompleteDataError {
    /// Writes a fixed, human-readable message; the byte count is not included.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        formatter.write_str("Buffer filled up")
    }
}

impl std::error::Error for MaybeIncompleteDataError {}
/// Packet type byte placed after the request header and echoed back as the
/// first byte of a response.
///
/// The discriminants are the on-wire values; variants are cast with `as u8`
/// when packets are built or checked.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum PacketType {
/// Type byte used by handshake packets.
Handshake = 0x09,
/// Type byte used by both basic and full query packets.
Query = 0x00,
}
/// First packet of an exchange; its response carries the challenge token
/// needed to build a query request.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct HandshakeRequest {
// Identifier sent big-endian in the request; the response must echo it.
session_id: u32,
}
impl HandshakeRequest {
    /// Packet type byte written into the request and expected back in the response.
    const PACKET_TYPE: PacketType = PacketType::Handshake;

    /// Serializes the request: `REQUEST_HEADER`, the packet type byte, then the
    /// session ID in big-endian order.
    fn bytes(&self) -> Vec<u8> {
        let mut vec = Vec::with_capacity(REQUEST_HEADER.len() + 1 + 4);
        vec.extend_from_slice(&REQUEST_HEADER);
        vec.push(Self::PACKET_TYPE as u8);
        vec.extend_from_slice(&self.session_id.to_be_bytes());
        vec
    }

    /// Sends the handshake over `socket` and waits for the reply.
    ///
    /// `timeout` is applied separately to the send and to the receive.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O or timeout error if either operation fails,
    /// and `io::ErrorKind::InvalidData` if the reply is too short, has the
    /// wrong packet type or session ID, or carries a challenge token that is
    /// not a decimal `i32`.
    async fn send(&self, socket: &UdpSocket, timeout: Duration) -> io::Result<HandshakeResponse> {
        tokioTimeout(timeout, socket.send(&self.bytes())).await??;

        let mut response_buffer = [0; 16];
        let bytes_read = tokioTimeout(timeout, socket.recv(&mut response_buffer)).await??;

        self.parse_response(&response_buffer[..bytes_read])
    }

    /// Decodes a handshake reply: one type byte, a big-endian session ID and a
    /// decimal challenge token that ends at the first NUL byte (or at the end
    /// of the packet when no NUL is present).
    fn parse_response(&self, packet: &[u8]) -> io::Result<HandshakeResponse> {
        // One type byte plus four session ID bytes are mandatory; rejecting
        // shorter packets here avoids panicking on truncated replies.
        if packet.len() < 5 {
            return Err(io::ErrorKind::InvalidData.into());
        }

        let packet_type = packet[0];
        let session_id = u32::from_be_bytes([packet[1], packet[2], packet[3], packet[4]]);
        if packet_type != Self::PACKET_TYPE as u8 || session_id != self.session_id {
            return Err(io::ErrorKind::InvalidData.into());
        }

        let payload = &packet[5..];
        let token_len = payload
            .iter()
            .position(|&b| b == 0x00)
            .unwrap_or(payload.len());
        let challenge_token = std::str::from_utf8(&payload[..token_len])
            .map_err(|_| io::ErrorKind::InvalidData)?
            .parse::<i32>()
            .map_err(|_| io::ErrorKind::InvalidData)?;

        Ok(HandshakeResponse {
            session_id,
            challenge_token,
        })
    }
}
/// Decoded reply to a `HandshakeRequest`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct HandshakeResponse {
// Session ID read from the reply (checked against the request's ID before
// this value is constructed).
session_id: u32,
// Challenge token parsed from the decimal text in the reply.
challenge_token: i32,
}
impl From<HandshakeResponse> for BasicQueryRequest {
fn from(res: HandshakeResponse) -> Self {
BasicQueryRequest {
session_id: res.session_id,
challenge_token: res.challenge_token,
}
}
}
impl From<HandshakeResponse> for FullQueryRequest {
fn from(res: HandshakeResponse) -> Self {
FullQueryRequest {
session_id: res.session_id,
challenge_token: res.challenge_token,
}
}
}
/// Request for the basic query; built from a `HandshakeResponse`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct BasicQueryRequest {
// Session ID sent big-endian; the response must echo it.
session_id: u32,
// Challenge token obtained from the handshake, sent big-endian.
challenge_token: i32,
}
impl BasicQueryRequest {
    /// Packet type byte written into the request and expected back in the response.
    const PACKET_TYPE: PacketType = PacketType::Query;

    /// Serializes the request: `REQUEST_HEADER`, the packet type byte, the
    /// big-endian session ID and the big-endian challenge token.
    fn bytes(&self) -> Vec<u8> {
        let mut vec = Vec::with_capacity(REQUEST_HEADER.len() + 1 + 4 + 4);
        vec.extend_from_slice(&REQUEST_HEADER);
        vec.push(Self::PACKET_TYPE as u8);
        vec.extend_from_slice(&self.session_id.to_be_bytes());
        vec.extend_from_slice(&self.challenge_token.to_be_bytes());
        vec
    }

    /// Sends the request over `socket`, receives the reply into `buffer` and
    /// decodes it.
    ///
    /// `timeout` is applied separately to the send and to the receive.
    ///
    /// # Errors
    ///
    /// * The underlying I/O or timeout error if sending or receiving fails.
    /// * `io::ErrorKind::Other` wrapping [`MaybeIncompleteDataError`] when the
    ///   reply fills `buffer` completely, since it may have been truncated.
    /// * `io::ErrorKind::InvalidData` when the reply is malformed, too short,
    ///   or does not match this request's packet type and session ID.
    async fn send(&self, socket: &UdpSocket, timeout: Duration, buffer: &mut [u8]) -> io::Result<BasicQueryResponse> {
        tokioTimeout(timeout, socket.send(&self.bytes())).await??;
        let bytes_read = tokioTimeout(timeout, socket.recv(buffer)).await??;
        if bytes_read == buffer.len() {
            return Err(io::Error::new(io::ErrorKind::Other, MaybeIncompleteDataError(bytes_read)));
        }
        self.parse_response(&buffer[..bytes_read])
    }

    /// Decodes a basic query reply.
    ///
    /// Layout: one type byte, a big-endian session ID, five NUL-terminated
    /// text fields (MOTD, game type, map, player count, player limit), a
    /// little-endian port and finally the host IP text, which ends at the
    /// first NUL or at the end of the packet.
    fn parse_response(&self, packet: &[u8]) -> io::Result<BasicQueryResponse> {
        /// Maps every byte to the character with the same code point.
        fn decode(bytes: &[u8]) -> String {
            bytes.iter().map(|&b| b as char).collect()
        }

        /// Splits the next NUL-terminated field off the front of `rest`,
        /// consuming the terminator. A missing terminator is an error.
        fn take_field<'a>(rest: &mut &'a [u8]) -> io::Result<&'a [u8]> {
            let current: &'a [u8] = *rest;
            let end = current
                .iter()
                .position(|&b| b == 0x00)
                .ok_or(io::ErrorKind::InvalidData)?;
            *rest = &current[end + 1..];
            Ok(&current[..end])
        }

        // Type byte and session ID are mandatory; bail out instead of
        // panicking when the packet is shorter than that.
        if packet.len() < 5 {
            return Err(io::ErrorKind::InvalidData.into());
        }
        let packet_type = packet[0];
        let session_id = u32::from_be_bytes([packet[1], packet[2], packet[3], packet[4]]);
        if packet_type != Self::PACKET_TYPE as u8 || session_id != self.session_id {
            return Err(io::ErrorKind::InvalidData.into());
        }

        let mut rest = &packet[5..];
        let motd = decode(take_field(&mut rest)?);
        let gametype = decode(take_field(&mut rest)?);
        let map = decode(take_field(&mut rest)?);
        let num_players = decode(take_field(&mut rest)?)
            .parse()
            .map_err(|_| io::ErrorKind::InvalidData)?;
        let max_players = decode(take_field(&mut rest)?)
            .parse()
            .map_err(|_| io::ErrorKind::InvalidData)?;

        // The port is raw binary and may itself contain a zero byte, so it is
        // read by position rather than by searching for a terminator.
        if rest.len() < 2 {
            return Err(io::ErrorKind::InvalidData.into());
        }
        let port = u16::from_le_bytes([rest[0], rest[1]]);

        let ip_bytes = &rest[2..];
        let ip_len = ip_bytes
            .iter()
            .position(|&b| b == 0x00)
            .unwrap_or(ip_bytes.len());
        let ip = String::from_utf8(ip_bytes[..ip_len].to_vec())
            .map_err(|_| io::ErrorKind::InvalidData)?;

        Ok(BasicQueryResponse {
            session_id,
            motd,
            gametype,
            map,
            num_players,
            max_players,
            port,
            ip,
        })
    }
}
/// Decoded reply to a basic query.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct BasicQueryResponse {
// Session ID echoed by the server.
session_id: u32,
/// First text field of the reply.
pub motd: String,
/// Second text field of the reply.
pub gametype: String,
/// Third text field of the reply.
pub map: String,
/// Fourth text field of the reply, parsed as a decimal number.
pub num_players: u32,
/// Fifth text field of the reply, parsed as a decimal number.
pub max_players: u32,
/// Port decoded from the two little-endian bytes that follow the text fields.
pub port: u16,
/// Text that follows the port bytes, up to the next NUL byte.
pub ip: String,
}
/// Request for the full query; built from a `HandshakeResponse`.
///
/// Derives the same traits as the other request types in this file so all
/// request values can be copied, compared and debug-printed uniformly.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct FullQueryRequest {
    // Session ID sent big-endian; the response must echo it.
    session_id: u32,
    // Challenge token obtained from the handshake, sent big-endian.
    challenge_token: i32,
}
impl FullQueryRequest {
    /// Packet type byte written into the request and expected back in the response.
    const PACKET_TYPE: PacketType = PacketType::Query;

    /// Serializes the request: `REQUEST_HEADER`, the packet type byte, the
    /// big-endian session ID, the big-endian challenge token and four trailing
    /// zero bytes (the only difference from the basic request's encoding).
    fn bytes(&self) -> Vec<u8> {
        let mut vec = Vec::with_capacity(REQUEST_HEADER.len() + 1 + 4 + 4 + 4);
        vec.extend_from_slice(&REQUEST_HEADER);
        vec.push(Self::PACKET_TYPE as u8);
        vec.extend_from_slice(&self.session_id.to_be_bytes());
        vec.extend_from_slice(&self.challenge_token.to_be_bytes());
        vec.extend_from_slice(&[0; 4]);
        vec
    }

    /// Sends the request over `socket`, receives the reply into `buffer` and
    /// decodes it.
    ///
    /// `timeout` is applied separately to the send and to the receive.
    ///
    /// # Errors
    ///
    /// * The underlying I/O or timeout error if sending or receiving fails.
    /// * `io::ErrorKind::Other` wrapping [`MaybeIncompleteDataError`] when the
    ///   reply fills `buffer` completely, since it may have been truncated.
    /// * `io::ErrorKind::InvalidData` when the reply is malformed, too short,
    ///   lacks the `splitnum` entry or the player section marker, or does not
    ///   match this request's packet type and session ID.
    async fn send(&self, socket: &UdpSocket, timeout: Duration, buffer: &mut [u8]) -> io::Result<FullQueryResponse> {
        tokioTimeout(timeout, socket.send(&self.bytes())).await??;
        let bytes_read = tokioTimeout(timeout, socket.recv(buffer)).await??;
        if bytes_read == buffer.len() {
            return Err(io::Error::new(io::ErrorKind::Other, MaybeIncompleteDataError(bytes_read)));
        }
        self.parse_response(&buffer[..bytes_read])
    }

    /// Decodes a full query reply.
    ///
    /// Layout: one type byte, a big-endian session ID, a sequence of
    /// NUL-terminated key/value pairs ended by an empty key, the marker
    /// `"\x01player_\0\0"`, and a list of NUL-terminated player names ended by
    /// an empty name (or by the end of the packet).
    fn parse_response(&self, packet: &[u8]) -> io::Result<FullQueryResponse> {
        /// Marker that separates the key/value section from the player list.
        const PLAYER_HEADER: &[u8] = b"\x01player_\0\0";

        /// Maps every byte to the character with the same code point.
        fn decode(bytes: &[u8]) -> String {
            bytes.iter().map(|&b| b as char).collect()
        }

        /// Splits the next NUL-terminated field off the front of `rest`,
        /// consuming the terminator. A missing terminator is an error.
        fn take_field<'a>(rest: &mut &'a [u8]) -> io::Result<&'a [u8]> {
            let current: &'a [u8] = *rest;
            let end = current
                .iter()
                .position(|&b| b == 0x00)
                .ok_or(io::ErrorKind::InvalidData)?;
            *rest = &current[end + 1..];
            Ok(&current[..end])
        }

        // Type byte and session ID are mandatory; bail out instead of
        // panicking when the packet is shorter than that.
        if packet.len() < 5 {
            return Err(io::ErrorKind::InvalidData.into());
        }
        let packet_type = packet[0];
        let session_id = u32::from_be_bytes([packet[1], packet[2], packet[3], packet[4]]);
        if packet_type != Self::PACKET_TYPE as u8 || session_id != self.session_id {
            return Err(io::ErrorKind::InvalidData.into());
        }

        let mut rest = &packet[5..];

        // Key/value section: pairs are read until an empty key is found.
        let mut kv = HashMap::<String, String>::new();
        loop {
            let key = take_field(&mut rest)?;
            if key.is_empty() {
                break;
            }
            let value = take_field(&mut rest)?;
            kv.insert(decode(key), decode(value));
        }

        // The `splitnum` entry must be present; it is not exposed to callers.
        kv.remove("splitnum").ok_or(io::ErrorKind::InvalidData)?;

        // The `plugins` entry has the form `software: plugin; plugin; ...`.
        // A value without ':' is treated as software with no plugin list
        // rather than being sliced out of bounds.
        let (software, plugins) = match kv.remove("plugins") {
            Some(raw) => match raw.find(':') {
                Some(colon) => {
                    let list = &raw[colon + 1..];
                    let plugins: Option<Vec<String>> = if list.is_empty() {
                        None
                    } else {
                        Some(list.split(';').map(|s| s.trim().to_owned()).collect())
                    };
                    (Some(raw[..colon].to_owned()), plugins)
                }
                None => (Some(raw), None),
            },
            None => (None, None),
        };

        if !rest.starts_with(PLAYER_HEADER) {
            return Err(io::ErrorKind::InvalidData.into());
        }
        rest = &rest[PLAYER_HEADER.len()..];

        // Each name is consumed together with its NUL terminator so that every
        // player is read, not just the first one; an empty name ends the list.
        let players: Vec<String> = rest
            .split(|&b| b == 0x00)
            .take_while(|name| !name.is_empty())
            .map(decode)
            .collect();

        Ok(FullQueryResponse {
            session_id,
            kv,
            players,
            software,
            plugins,
        })
    }
}
/// Decoded reply to a full query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FullQueryResponse {
// Session ID echoed by the server.
session_id: u32,
/// Key/value pairs from the reply, without the `splitnum` and `plugins`
/// entries (those are removed during decoding).
pub kv: HashMap<String, String>,
/// Player names read from the player section of the reply.
pub players: Vec<String>,
/// Leading part of the `plugins` entry, up to its first `:`;
/// `None` when the reply has no `plugins` entry.
pub software: Option<String>,
/// Remainder of the `plugins` entry split on `;` with each item trimmed;
/// `None` when there is no such remainder.
pub plugins: Option<Vec<String>>
}
/// Result of a query, tagged by which kind of query produced it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Query {
/// Result of [`Query::get_basic`].
Basic(BasicQueryResponse),
/// Result of [`Query::get_full`].
Full(FullQueryResponse),
}
impl Query {
pub async fn get_basic<A: ToSocketAddrs>(address: A, timeout: Duration, buffer: &mut [u8]) -> io::Result<Self> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(address).await?;
let handshake_req = HandshakeRequest {
session_id: gen_session_id(),
};
let handshake_res = handshake_req.send(&socket, timeout).await?;
let query_req: BasicQueryRequest = handshake_res.into();
buffer.iter_mut().for_each(|b| *b = 0);
let query_res = query_req.send(&socket, timeout, buffer).await?;
Ok(Query::Basic(query_res))
}
pub async fn get_full<A: ToSocketAddrs>(address: A, timeout: Duration, buffer: &mut [u8]) -> io::Result<Self> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(address).await?;
let handshake_req = HandshakeRequest {
session_id: gen_session_id(),
};
let handshake_res = handshake_req.send(&socket, timeout).await?;
let query_req: FullQueryRequest = handshake_res.into();
buffer.iter_mut().for_each(|b| *b = 0);
let query_res = query_req.send(&socket, timeout, buffer).await?;
Ok(Query::Full(query_res))
}
}