#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/icmp.h>
#include <linux/tcp.h>
#include <linux/in.h>
/* NOTE(review): MAX_CHECKING is not referenced anywhere in this file. */
#define MAX_CHECKING 4
/* Upper bound on the 16-bit words sum16() will add (750 words = 1500 bytes);
 * keeps its loop bounded so it can be unrolled. */
#define MAX_CSUM_WORDS 750
/*
 * sum16 - add up a memory region as 16-bit words for an Internet checksum.
 * @data:     start of the region to sum
 * @size:     number of bytes to sum
 * @data_end: end of readable packet data; nothing at or past it is read
 *
 * Returns the unfolded 32-bit sum of the words, loaded in native (memory)
 * byte order; the caller folds it to 16 bits and complements it.
 * At most MAX_CSUM_WORDS words are summed so the loop stays bounded. If
 * data_end is reached before @size bytes, the partial sum is returned.
 * 750 words of at most 0xffff plus one trailing byte cannot overflow a __u32.
 */
static __always_inline __u32 sum16(const void* data, __u32 size, const void* data_end) {
	__u32 sum = 0;
	const __u16 *ptr = (const __u16 *)data;
	const void *end = data + size;

#pragma unroll
	for (int i = 0; i < MAX_CSUM_WORDS; ++i) {
		if ((const void *)(ptr + 1) > end)
			break;
		/* Packet bounds check required before every read. */
		if ((const void *)(ptr + 1) > data_end)
			return sum;
		sum += *ptr;
		ptr++;
	}

	/*
	 * Odd trailing byte: it is padded with a zero byte that follows it in
	 * memory. Because the words above were loaded in native byte order,
	 * that padded word is worth the byte itself on little-endian targets
	 * and byte << 8 on big-endian ones. Only add it when the loop really
	 * stopped one byte short of the end (not when MAX_CSUM_WORDS cut it off).
	 */
	const __u8 *byte_ptr = (const __u8 *)ptr;
	if ((size & 1) && (const void *)(byte_ptr + 1) == end &&
	    (const void *)(byte_ptr + 1) <= data_end) {
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
		sum += *byte_ptr;
#else
		sum += (__u32)*byte_ptr << 8;
#endif
	}
	return sum;
}
/*
 * tcp_bounce - answer TCP SYNs with a SYN-ACK directly from XDP.
 *
 * The program decides as follows:
 *  - XDP_PASS: non-IPv4 or non-TCP frames, and frames too short or too
 *    malformed to parse.
 *  - XDP_DROP: TCP segments that are not a plain SYN (SYN clear, or ACK set).
 *  - XDP_TX: a SYN is rewritten in place and sent back. The rewrite swaps
 *    the Ethernet and IP addresses and the TCP ports, sets ACK, sets
 *    ack_seq = seq + 1 and recomputes the TCP checksum.
 *
 * The IPv4 header checksum is left as-is: swapping saddr/daddr does not
 * change a one's-complement sum.
 * NOTE(review): the reply reuses the client's sequence number as its own
 * and echoes any TCP options/payload unchanged.
 */
SEC("xdp")
int tcp_bounce(struct xdp_md *ctx) {
	void *data = (void *)(long)ctx->data;
	void *data_end = (void *)(long)ctx->data_end;

	struct ethhdr *eth = data;
	if ((void *)eth + sizeof(*eth) > data_end)
		return XDP_PASS; /* not enough data */
	if (eth->h_proto != bpf_htons(ETH_P_IP))
		return XDP_PASS;

	struct iphdr *iph = data + sizeof(*eth);
	if ((void *)iph + sizeof(*iph) > data_end)
		return XDP_PASS;
	if (iph->protocol != IPPROTO_TCP)
		return XDP_PASS;

	/* ihl counts 32-bit words; anything below 5 is a malformed header. */
	__u32 ip_hdr_len = iph->ihl * 4;
	if (ip_hdr_len < sizeof(*iph))
		return XDP_PASS;
	if ((void *)iph + ip_hdr_len > data_end)
		return XDP_PASS;

	struct tcphdr *tcph = (void *)iph + ip_hdr_len;
	if ((void *)tcph + sizeof(*tcph) > data_end)
		return XDP_PASS;

	/* Only a plain SYN is bounced; every other TCP segment is dropped. */
	if (!tcph->syn || tcph->ack)
		return XDP_DROP;

	/* tot_len must cover both headers, otherwise tcp_len would underflow. */
	__u16 tot_len = bpf_ntohs(iph->tot_len);
	if (tot_len < ip_hdr_len + sizeof(*tcph))
		return XDP_PASS;
	__u16 tcp_len = tot_len - ip_hdr_len;

	/* swap MAC addresses */
	__u8 tmp_mac[ETH_ALEN];
	__builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN);
	__builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN);
	__builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN);

	/* swap IP addresses */
	__be32 tmp_ip = iph->saddr;
	iph->saddr = iph->daddr;
	iph->daddr = tmp_ip;

	/* swap ports */
	__be16 tmp_port = tcph->source;
	tcph->source = tcph->dest;
	tcph->dest = tmp_port;

	/* turn the SYN into a SYN-ACK acknowledging the client's sequence */
	tcph->ack = 1;
	tcph->ack_seq = bpf_htonl(bpf_ntohl(tcph->seq) + 1);

	/*
	 * Recompute the TCP checksum over pseudo-header + segment; the check
	 * field must be zero while summing. saddr/daddr lie inside the iphdr
	 * already bounds-checked above. A 64-bit accumulator keeps carries
	 * that a 32-bit one could drop: bpf_csum_diff() may return a full
	 * 32-bit partial sum before sum16()'s result is added on top.
	 */
	tcph->check = 0;
	__u64 csum = (__u32)bpf_csum_diff(0, 0, (__be32 *)&iph->saddr, 8, 0);
	csum += bpf_htons(IPPROTO_TCP); /* pseudo-header: zero byte + protocol */
	csum += bpf_htons(tcp_len);     /* pseudo-header: TCP segment length */
	csum += sum16(tcph, tcp_len, data_end);

	/* Fold 64 -> 16 bits in a fixed number of steps (no data-dependent loop). */
	csum = (csum & 0xffffffff) + (csum >> 32);
	csum = (csum & 0xffffffff) + (csum >> 32);
	csum = (csum & 0xffff) + (csum >> 16);
	csum = (csum & 0xffff) + (csum >> 16);
	tcph->check = (__be16)~csum;
	return XDP_TX;
}
/* License string placed in the "license" section for the loader; the kernel
 * uses it to decide whether GPL-only BPF helpers may be called. */
char _license[] SEC("license") = "GPL";