diff --git a/src/core/hle/service/nwm/nwm_uds.cpp b/src/core/hle/service/nwm/nwm_uds.cpp index c08b9626e..437c9d321 100644 --- a/src/core/hle/service/nwm/nwm_uds.cpp +++ b/src/core/hle/service/nwm/nwm_uds.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "common/common_types.h" #include "common/logging/log.h" @@ -79,7 +80,11 @@ static u8 network_channel = DefaultNetworkChannel; static NetworkInfo network_info; // Mapping of mac addresses to their respective node_ids. -static std::map node_map; +struct Node { + bool connected; + u16 node_id; +}; +static std::map node_map; // Event that will generate and send the 802.11 beacon frames. static CoreTiming::EventType* beacon_broadcast_event; @@ -107,6 +112,9 @@ static std::list received_beacons; // Network node id used when a SecureData packet is addressed to every connected node. constexpr u16 BroadcastNetworkNodeId = 0xFFFF; +// The Host has always dest_node_id 1 +constexpr u16 HostDestNodeId = 1; + /** * Returns a list of received 802.11 beacon frames from the specified sender since the last call. */ @@ -159,16 +167,20 @@ static void BroadcastNodeMap() { packet.channel = network_channel; packet.type = Network::WifiPacket::PacketType::NodeMap; packet.destination_address = Network::BroadcastMac; - std::size_t size = node_map.size(); + std::size_t num_entries = std::count_if(node_map.begin(), node_map.end(), + [](const auto& node) { return node.second.connected; }); using node_t = decltype(node_map)::value_type; - packet.data.resize(sizeof(size) + (sizeof(node_t::first) + sizeof(node_t::second)) * size); - std::memcpy(packet.data.data(), &size, sizeof(size)); - std::size_t offset = sizeof(size); + packet.data.resize(sizeof(num_entries) + + (sizeof(node_t::first) + sizeof(node_t::second.node_id)) * num_entries); + std::memcpy(packet.data.data(), &num_entries, sizeof(num_entries)); + std::size_t offset = sizeof(num_entries); for (const auto& node : node_map) { - std::memcpy(packet.data.data() + offset, node.first.data(), sizeof(node.first)); - std::memcpy(packet.data.data() + offset + sizeof(node.first), &node.second, - sizeof(node.second)); - offset += sizeof(node.first) + sizeof(node.second); + if (node.second.connected) { + std::memcpy(packet.data.data() + offset, node.first.data(), sizeof(node.first)); + std::memcpy(packet.data.data() + offset + sizeof(node.first), &node.second.node_id, + sizeof(node.second.node_id)); + offset += sizeof(node.first) + sizeof(node.second.node_id); + } } SendPacket(packet); @@ -176,6 +188,11 @@ static void BroadcastNodeMap() { static void HandleNodeMapPacket(const Network::WifiPacket& packet) { std::lock_guard lock(connection_status_mutex); + if (connection_status.status == static_cast(NetworkStatus::ConnectedAsHost)) { + LOG_DEBUG(Service_NWM, "Ignored NodeMapPacket since connection_status is host"); + return; + } + node_map.clear(); std::size_t num_entries; Network::MacAddress address; @@ -185,7 +202,8 @@ static void HandleNodeMapPacket(const Network::WifiPacket& packet) { for (std::size_t i = 0; i < num_entries; ++i) { std::memcpy(&address, packet.data.data() + offset, sizeof(address)); std::memcpy(&id, packet.data.data() + offset + sizeof(address), sizeof(id)); - node_map[address] = id; + node_map[address].connected = true; + node_map[address].node_id = id; offset += sizeof(address) + sizeof(id); } } @@ -218,7 +236,12 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) { "Could not join network"); { std::lock_guard lock(connection_status_mutex); - ASSERT(connection_status.status == static_cast(NetworkStatus::Connecting)); + if (connection_status.status != static_cast(NetworkStatus::Connecting)) { + LOG_DEBUG(Service_NWM, + "Ignored AssociationResponseFrame because connection status is {}", + connection_status.status); + return; + } } // Send the EAPoL-Start packet to the server. @@ -245,14 +268,21 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { return; } - auto node = DeserializeNodeInfoFromFrame(packet.data); - - if (connection_status.max_nodes == connection_status.total_nodes) { - // Reject connection attempt - LOG_ERROR(Service_NWM, "Reached maximum nodes, but reject packet wasn't sent."); - // TODO(B3N30): Figure out what packet is sent here + auto node_it = node_map.find(packet.transmitter_address); + if (node_it == node_map.end()) { + LOG_DEBUG(Service_NWM, "Connection sequence aborted, because the AuthenticationFrame " + "of the client wasn't recieved"); return; } + if (node_it->second.connected) { + LOG_DEBUG(Service_NWM, + "Connection sequence aborted, because the client is already connected"); + return; + } + + ASSERT(connection_status.max_nodes != connection_status.total_nodes); + + auto node = DeserializeNodeInfoFromFrame(packet.data); // Get an unused network node id u16 node_id = GetNextAvailableNodeId(); @@ -268,7 +298,8 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { network_info.total_nodes++; - node_map[packet.transmitter_address] = node.network_node_id; + node_map[packet.transmitter_address].node_id = node.network_node_id; + node_map[packet.transmitter_address].connected = true; BroadcastNodeMap(); @@ -321,6 +352,7 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { connection_status_event->Signal(); connection_event->Signal(); } else if (connection_status.status == static_cast(NetworkStatus::ConnectedAsClient)) { + // TODO(B3N30): Remove that section and send/receive a proper connection_status packet // On a 3ds this packet wouldn't be addressed to already connected clients // We use this information because in the current implementation the host // isn't broadcasting the node information @@ -349,6 +381,14 @@ static void HandleSecureDataPacket(const Network::WifiPacket& packet) { std::unique_lock lock(connection_status_mutex, std::defer_lock); std::lock(hle_lock, lock); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsClient)) { + // TODO(B3N30): Handle spectators + LOG_DEBUG(Service_NWM, "Ignored SecureDataPacket, because connection status is {}", + connection_status.status); + return; + } + if (secure_data.src_node_id == connection_status.network_node_id) { // Ignore packets that came from ourselves. return; @@ -464,12 +504,24 @@ void HandleAuthenticationFrame(const Network::WifiPacket& packet) { connection_status.status); return; } + if (node_map.find(packet.transmitter_address) != node_map.end()) { + LOG_ERROR(Service_NWM, "Connection sequence aborted, because there is already a " + "connected client with that MAC-Adress"); + return; + } + if (connection_status.max_nodes == connection_status.total_nodes) { + // Reject connection attempt + LOG_ERROR(Service_NWM, "Reached maximum nodes, but reject packet wasn't sent."); + // TODO(B3N30): Figure out what packet is sent here + return; + } // Respond with an authentication response frame with SEQ2 auth_request.channel = network_channel; auth_request.data = GenerateAuthenticationFrame(AuthenticationSeq::SEQ2); auth_request.destination_address = packet.transmitter_address; auth_request.type = WifiPacket::PacketType::Authentication; + node_map[packet.transmitter_address].connected = false; } SendPacket(auth_request); @@ -492,17 +544,29 @@ void HandleDeauthenticationFrame(const Network::WifiPacket& packet) { return; } - u16 node_id = node_map[packet.transmitter_address]; - auto node = std::find_if(node_info.begin(), node_info.end(), [&node_id](const NodeInfo& info) { - return info.network_node_id == node_id; - }); - ASSERT(node != node_info.end()); + Node node = node_map[packet.transmitter_address]; + node_map.erase(packet.transmitter_address); - connection_status.node_bitmask &= ~(1 << (node_id - 1)); - connection_status.changed_nodes |= 1 << (node_id - 1); + if (!node.connected) { + LOG_DEBUG(Service_NWM, "Received DeauthenticationFrame from a not connected MAC Address"); + return; + } + + auto node_it = std::find_if(node_info.begin(), node_info.end(), [&node](const NodeInfo& info) { + return info.network_node_id == node.node_id; + }); + ASSERT(node_it != node_info.end()); + + connection_status.node_bitmask &= ~(1 << (node.node_id - 1)); + connection_status.changed_nodes |= 1 << (node.node_id - 1); connection_status.total_nodes--; + connection_status.nodes[node.node_id - 1] = 0; network_info.total_nodes--; + // TODO(B3N30): broadcast new connection_status to clients + + node_it->Reset(); + connection_status_event->Signal(); } @@ -541,6 +605,26 @@ void OnWifiPacketReceived(const Network::WifiPacket& packet) { } } +static boost::optional GetNodeMacAddress(u16 dest_node_id, u8 flags) { + constexpr u8 BroadcastFlag = 0x2; + if ((flags & BroadcastFlag) || dest_node_id == BroadcastNetworkNodeId) { + // Broadcast + return Network::BroadcastMac; + } else if (dest_node_id == HostDestNodeId) { + // Destination is host + return network_info.host_mac_address; + } + // Destination is a specific client + auto destination = + std::find_if(node_map.begin(), node_map.end(), [dest_node_id](const auto& node) { + return node.second.node_id == dest_node_id && node.second.connected; + }); + if (destination == node_map.end()) { + return {}; + } + return destination->first; +} + void NWM_UDS::Shutdown(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x03, 0, 0); @@ -656,6 +740,7 @@ void NWM_UDS::InitializeWithVersion(Kernel::HLERequestContext& ctx) { connection_status.status = static_cast(NetworkStatus::NotConnected); node_info.clear(); node_info.push_back(current_node); + channel_data.clear(); } IPC::RequestBuilder rb = rp.MakeBuilder(1, 2); @@ -1000,30 +1085,15 @@ void NWM_UDS::SendTo(Kernel::HLERequestContext& ctx) { return; } - Network::MacAddress dest_address; - if (flags >> 2) { LOG_ERROR(Service_NWM, "Unexpected flags 0x{:02X}", flags); } - if ((flags & (0x1 << 1)) || dest_node_id == 0xFFFF) { - // Broadcast - dest_address = Network::BroadcastMac; - } else if (dest_node_id != 1) { - // Send to specific client - auto destination = - std::find_if(node_map.begin(), node_map.end(), - [dest_node_id](const auto& node) { return node.second == dest_node_id; }); - if (destination == node_map.end()) { - LOG_ERROR(Service_NWM, "tried to send packet to unknown dest id {}", dest_node_id); - rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, - ErrorSummary::WrongArgument, ErrorLevel::Status)); - return; - } - dest_address = destination->first; - } else { - // Send message to host - dest_address = network_info.host_mac_address; + auto dest_address = GetNodeMacAddress(dest_node_id, flags); + if (!dest_address) { + rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Status)); + return; } constexpr std::size_t MaxSize = 0x5C6; @@ -1039,12 +1109,12 @@ void NWM_UDS::SendTo(Kernel::HLERequestContext& ctx) { GenerateDataPayload(input_buffer, data_channel, dest_node_id, connection_status.network_node_id, sequence_number); - // TODO(B3N30): Retrieve the MAC address of the dest_node_id and our own to encrypt + // TODO(B3N30): Use the MAC address of the dest_node_id and our own to encrypt // and encapsulate the payload. Network::WifiPacket packet; - packet.destination_address = dest_address; + packet.destination_address = *dest_address; packet.channel = network_channel; packet.data = std::move(data_payload); packet.type = Network::WifiPacket::PacketType::Data; diff --git a/src/core/hle/service/nwm/nwm_uds.h b/src/core/hle/service/nwm/nwm_uds.h index d10b99fc9..2c79aa674 100644 --- a/src/core/hle/service/nwm/nwm_uds.h +++ b/src/core/hle/service/nwm/nwm_uds.h @@ -36,6 +36,12 @@ struct NodeInfo { INSERT_PADDING_BYTES(4); u16_le network_node_id; INSERT_PADDING_BYTES(6); + + void Reset() { + friend_code_seed = 0; + username.fill(0); + network_node_id = 0; + } }; static_assert(sizeof(NodeInfo) == 40, "NodeInfo has incorrect size.");