ping 是一个经常被用来检查主机间连通性的工具, 它基于 ICMP 协议实现, 基本原理很简单: 本机给远程机器发送 ICMP 报文, 远程主机接收到 ICMP 报文后便会回复一个类似的 ICMP 报文; 当本机接收到回复后变认为远程主机是可连接的, 否则便认为这个主机是不可达的.

为了了解 golang 的网络编程, 我用 go 实现了一个 ping 命令, 本文会介绍如何实现 ping 命令.

Demo

这里有完整的示例代码, 可以直接执行实现下面的效果 (注意需要 sudo 权限):

➜  ping git:(master) sudo go run goping.go baidu.com
Ping 111.13.101.208 (baidu.com):

28 bytes from 111.13.101.208: seq=1 time=9ms
28 bytes from 111.13.101.208: seq=2 time=9ms
28 bytes from 111.13.101.208: seq=3 time=10ms
28 bytes from 111.13.101.208: seq=4 time=10ms
28 bytes from 111.13.101.208: seq=5 time=9ms

如何实现

ICMP 报文

首先我们需要定义出 ICMP 报文头的结构:

type ICMP struct {
Type uint8
Code uint8
CheckSum uint16
Identifier uint16
SequenceNum uint16

其中 Type 表明的是 ICMP 的类型, Code 则用来进一步划分 ICMP 的类型, ping 使用的是 echo 类型的 ICMP, 这两个值需要分别设置为 8 和 0.

CheckSum 是报文头的校验值, 以防止在网络传输过程中的数据错误. 会先把这个字段设置为 0 来计算校验值, 计算完成后再把校验值赋值到这个字段.

ID 是用来标识一个 ICMP, 可以设置为 0; 而 SequenceNum 则是序列号, 可以在发送 ICMP 报文的时候依次累加.

这篇文章对 ICMP 的结构有更详细的介绍.

基于上面的描述, 我们可以实现下面这个基于序列号生成 ICMP 报文头的函数:

func getICMP(seq uint16) ICMP {
icmp := ICMP{
Type: 8,
Code: 0,
CheckSum: 0,
Identifier: 0,
SequenceNum: seq,
}

var buffer bytes.Buffer
binary.Write(&buffer, binary.BigEndian, icmp)
icmp.CheckSum = CheckSum(buffer.Bytes())
buffer.Reset()

return icmp
}

其中 CheckSum() 是用来计算校验值的函数. 在网络中传输的数据需要是大端字节序的.

发送及接收 ICMP 报文

首先, 我们使用 net.DialIP("ip4:icmp", nil, destAddr) 来创建一个 ICMP 报文.

接着我们使用下面的代码填充 ICMP 报文并发送:

binary.Write(&buffer, binary.BigEndian, icmp)

if _, err := conn.Write(buffer.Bytes()); err != nil {
return err
}

发送完之后, 我们使用下面的命令接收请求:

recv := make([]byte, 1024)
receiveCnt, err := conn.Read(recv)

同时我们还需要统计发送到接收之间所耗费的时间.

完整的代码如下所示:

func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr) error {
conn, err := net.DialIP("ip4:icmp", nil, destAddr)
if err != nil {
fmt.Printf("Fail to connect to remote host: %s\n", err)
return err
}
defer conn.Close()

var buffer bytes.Buffer
binary.Write(&buffer, binary.BigEndian, icmp)

if _, err := conn.Write(buffer.Bytes()); err != nil {
return err
}

tStart := time.Now()

conn.SetReadDeadline((time.Now().Add(time.Second * 2)))

recv := make([]byte, 1024)
receiveCnt, err := conn.Read(recv)

if err != nil {
return err
}

tEnd := time.Now()
duration := tEnd.Sub(tStart).Nanoseconds() / 1e6

fmt.Printf("%d bytes from %s: seq=%d time=%dms\n", receiveCnt, destAddr.String(), icmp.SequenceNum, duration)

return err
}

ping 命令的完整代码

Github 上的文件路径

package main

import (
"bytes"
"encoding/binary"
"fmt"
"net"
"os"
"time"
)

type ICMP struct {
Type uint8
Code uint8
CheckSum uint16
Identifier uint16
SequenceNum uint16
}

func usage() {
msg := `
Need to run as root!

Usage:
goping host

Example: ./goping www.baidu.com`

fmt.Println(msg)
os.Exit(0)
}

func getICMP(seq uint16) ICMP {
icmp := ICMP{
Type: 8,
Code: 0,
CheckSum: 0,
Identifier: 0,
SequenceNum: seq,
}

var buffer bytes.Buffer
binary.Write(&buffer, binary.BigEndian, icmp)
icmp.CheckSum = CheckSum(buffer.Bytes())
buffer.Reset()

return icmp
}

func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr) error {
conn, err := net.DialIP("ip4:icmp", nil, destAddr)
if err != nil {
fmt.Printf("Fail to connect to remote host: %s\n", err)
return err
}
defer conn.Close()

var buffer bytes.Buffer
binary.Write(&buffer, binary.BigEndian, icmp)

if _, err := conn.Write(buffer.Bytes()); err != nil {
return err
}

tStart := time.Now()

conn.SetReadDeadline((time.Now().Add(time.Second * 2)))

recv := make([]byte, 1024)
receiveCnt, err := conn.Read(recv)

if err != nil {
return err
}

tEnd := time.Now()
duration := tEnd.Sub(tStart).Nanoseconds() / 1e6

fmt.Printf("%d bytes from %s: seq=%d time=%dms\n", receiveCnt, destAddr.String(), icmp.SequenceNum, duration)

return err
}

func CheckSum(data []byte) uint16 {
var (
sum uint32
length int = len(data)
index int
)
for length > 1 {
sum += uint32(data[index])<<8 + uint32(data[index+1])
index += 2
length -= 2
}
if length > 0 {
sum += uint32(data[index])
}
sum += (sum >> 16)

return uint16(^sum)
}

func main() {
if len(os.Args) < 2 {
usage()
}

host := os.Args[1]
raddr, err := net.ResolveIPAddr("ip", host)
if err != nil {
fmt.Printf("Fail to resolve %s, %s\n", host, err)
return
}

fmt.Printf("Ping %s (%s):\n\n", raddr.String(), host)

for i := 1; i < 6; i++ {
if err = sendICMPRequest(getICMP(uint16(i)), raddr); err != nil {
fmt.Printf("Error: %s\n", err)
}
time.Sleep(2 * time.Second)
}
}

References