Source file
src/os/readfrom_linux_test.go
1
2
3
4
5 package os_test
6
7 import (
8 "bytes"
9 "errors"
10 "internal/poll"
11 "internal/testpty"
12 "io"
13 "math/rand"
14 "net"
15 . "os"
16 "path/filepath"
17 "strconv"
18 "sync"
19 "syscall"
20 "testing"
21 "time"
22 )
23
24 func TestSpliceFile(t *testing.T) {
25 sizes := []int{
26 1,
27 42,
28 1025,
29 syscall.Getpagesize() + 1,
30 32769,
31 }
32 t.Run("Basic-TCP", func(t *testing.T) {
33 for _, size := range sizes {
34 t.Run(strconv.Itoa(size), func(t *testing.T) {
35 testSpliceFile(t, "tcp", int64(size), -1)
36 })
37 }
38 })
39 t.Run("Basic-Unix", func(t *testing.T) {
40 for _, size := range sizes {
41 t.Run(strconv.Itoa(size), func(t *testing.T) {
42 testSpliceFile(t, "unix", int64(size), -1)
43 })
44 }
45 })
46 t.Run("TCP-To-TTY", func(t *testing.T) {
47 testSpliceToTTY(t, "tcp", 32768)
48 })
49 t.Run("Unix-To-TTY", func(t *testing.T) {
50 testSpliceToTTY(t, "unix", 32768)
51 })
52 t.Run("Limited", func(t *testing.T) {
53 t.Run("OneLess-TCP", func(t *testing.T) {
54 for _, size := range sizes {
55 t.Run(strconv.Itoa(size), func(t *testing.T) {
56 testSpliceFile(t, "tcp", int64(size), int64(size)-1)
57 })
58 }
59 })
60 t.Run("OneLess-Unix", func(t *testing.T) {
61 for _, size := range sizes {
62 t.Run(strconv.Itoa(size), func(t *testing.T) {
63 testSpliceFile(t, "unix", int64(size), int64(size)-1)
64 })
65 }
66 })
67 t.Run("Half-TCP", func(t *testing.T) {
68 for _, size := range sizes {
69 t.Run(strconv.Itoa(size), func(t *testing.T) {
70 testSpliceFile(t, "tcp", int64(size), int64(size)/2)
71 })
72 }
73 })
74 t.Run("Half-Unix", func(t *testing.T) {
75 for _, size := range sizes {
76 t.Run(strconv.Itoa(size), func(t *testing.T) {
77 testSpliceFile(t, "unix", int64(size), int64(size)/2)
78 })
79 }
80 })
81 t.Run("More-TCP", func(t *testing.T) {
82 for _, size := range sizes {
83 t.Run(strconv.Itoa(size), func(t *testing.T) {
84 testSpliceFile(t, "tcp", int64(size), int64(size)+1)
85 })
86 }
87 })
88 t.Run("More-Unix", func(t *testing.T) {
89 for _, size := range sizes {
90 t.Run(strconv.Itoa(size), func(t *testing.T) {
91 testSpliceFile(t, "unix", int64(size), int64(size)+1)
92 })
93 }
94 })
95 })
96 }
97
98 func testSpliceFile(t *testing.T, proto string, size, limit int64) {
99 dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
100 defer cleanup()
101
102
103 var (
104 r io.Reader
105 lr *io.LimitedReader
106 )
107 if limit >= 0 {
108 lr = &io.LimitedReader{N: limit, R: src}
109 r = lr
110 if limit < int64(len(data)) {
111 data = data[:limit]
112 }
113 } else {
114 r = src
115 }
116
117 n, err := io.Copy(dst, r)
118 if err != nil {
119 t.Fatal(err)
120 }
121
122
123 if n > 0 && !hook.called {
124 t.Fatal("expected to called poll.Splice")
125 }
126 if hook.called && hook.dstfd != int(dst.Fd()) {
127 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
128 }
129 sc, ok := src.(syscall.Conn)
130 if !ok {
131 t.Fatalf("server Conn is not a syscall.Conn")
132 }
133 rc, err := sc.SyscallConn()
134 if err != nil {
135 t.Fatalf("server Conn SyscallConn error: %v", err)
136 }
137 if err = rc.Control(func(fd uintptr) {
138 if hook.called && hook.srcfd != int(fd) {
139 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
140 }
141 }); err != nil {
142 t.Fatalf("server Conn Control error: %v", err)
143 }
144
145
146
147
148 dstoff, err := dst.Seek(0, io.SeekCurrent)
149 if err != nil {
150 t.Fatal(err)
151 }
152 if dstoff != int64(len(data)) {
153 t.Errorf("dstoff = %d, want %d", dstoff, len(data))
154 }
155 if n != int64(len(data)) {
156 t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
157 }
158 mustSeekStart(t, dst)
159 mustContainData(t, dst, data)
160
161
162 if lr != nil {
163 if want := limit - n; lr.N != want {
164 t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
165 }
166 }
167 }
168
169
170 func testSpliceToTTY(t *testing.T, proto string, size int64) {
171 var wg sync.WaitGroup
172
173
174
175
176 defer wg.Wait()
177
178 pty, ttyName, err := testpty.Open()
179 if err != nil {
180 t.Skipf("skipping test because pty open failed: %v", err)
181 }
182 defer pty.Close()
183
184
185
186
187 ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
188 if err != nil {
189 t.Skipf("skipping test because failed to open tty: %v", err)
190 }
191 defer syscall.Close(ttyFD)
192
193 tty := NewFile(uintptr(ttyFD), "tty")
194 defer tty.Close()
195
196 client, server := createSocketPair(t, proto)
197
198 data := bytes.Repeat([]byte{'a'}, int(size))
199
200 wg.Add(1)
201 go func() {
202 defer wg.Done()
203
204
205
206 for i := 0; i < len(data); i += 1024 {
207 if _, err := client.Write(data[i : i+1024]); err != nil {
208
209
210 if !errors.Is(err, net.ErrClosed) {
211 t.Errorf("error writing to socket: %v", err)
212 }
213 return
214 }
215 }
216 client.Close()
217 }()
218
219 wg.Add(1)
220 go func() {
221 defer wg.Done()
222 buf := make([]byte, 32)
223 for {
224 if _, err := pty.Read(buf); err != nil {
225 if err != io.EOF && !errors.Is(err, ErrClosed) {
226
227
228 t.Logf("error reading from pty: %v", err)
229 }
230 return
231 }
232 }
233 }()
234
235
236 defer client.Close()
237
238 _, err = io.Copy(tty, server)
239 if err != nil {
240 t.Fatal(err)
241 }
242 }
243
244 var (
245 copyFileTests = []copyFileTestFunc{newCopyFileRangeTest}
246 copyFileHooks = []copyFileTestHook{hookCopyFileRange}
247 )
248
// testCopyFiles runs the copy-file test for the copy mechanism covered by
// this file by forwarding size and limit unchanged to testCopyFileRange.
func testCopyFiles(t *testing.T, size, limit int64) {
	testCopyFileRange(t, size, limit)
}
252
253 func testCopyFileRange(t *testing.T, size int64, limit int64) {
254 dst, src, data, hook, name := newCopyFileRangeTest(t, size)
255 testCopyFile(t, dst, src, data, hook, limit, name)
256 }
257
258
259
260
261
262 func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileHook, name string) {
263 t.Helper()
264
265 name = "newCopyFileRangeTest"
266
267 dst, src, data = newCopyFileTest(t, size)
268 hook, _ = hookCopyFileRange(t)
269
270 return
271 }
272
273
274
275
276
277
278 func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
279 t.Helper()
280
281 hook := hookSpliceFile(t)
282
283 client, server := createSocketPair(t, proto)
284
285 dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
286 if err != nil {
287 t.Fatal(err)
288 }
289 t.Cleanup(func() { dst.Close() })
290
291 randSeed := time.Now().Unix()
292 t.Logf("random data seed: %d\n", randSeed)
293 prng := rand.New(rand.NewSource(randSeed))
294 data := make([]byte, size)
295 prng.Read(data)
296
297 done := make(chan struct{})
298 go func() {
299 client.Write(data)
300 client.Close()
301 close(done)
302 }()
303
304 return dst, server, data, hook, func() { <-done }
305 }
306
307 func hookCopyFileRange(t *testing.T) (hook *copyFileHook, name string) {
308 name = "hookCopyFileRange"
309
310 hook = new(copyFileHook)
311 orig := *PollCopyFileRangeP
312 t.Cleanup(func() {
313 *PollCopyFileRangeP = orig
314 })
315 *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
316 hook.called = true
317 hook.dstfd = dst.Sysfd
318 hook.srcfd = src.Sysfd
319 hook.written, hook.handled, hook.err = orig(dst, src, remain)
320 return hook.written, hook.handled, hook.err
321 }
322 return
323 }
324
325 func hookSpliceFile(t *testing.T) *spliceFileHook {
326 h := new(spliceFileHook)
327 h.install()
328 t.Cleanup(h.uninstall)
329 return h
330 }
331
// spliceFileHook records the arguments and results of the most recent call
// made through *PollSpliceFile while the hook is installed.
type spliceFileHook struct {
	called bool  // set to true once the hooked function has been invoked
	dstfd  int   // Sysfd of the dst *poll.FD passed to the last call
	srcfd  int   // Sysfd of the src *poll.FD passed to the last call
	remain int64 // remain argument passed to the last call

	// Results returned by the original function for the last call.
	written int64
	handled bool
	err     error

	// original is the value *PollSpliceFile held when install was called;
	// the hook delegates to it and uninstall restores it.
	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
}
344
345 func (h *spliceFileHook) install() {
346 h.original = *PollSpliceFile
347 *PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
348 h.called = true
349 h.dstfd = dst.Sysfd
350 h.srcfd = src.Sysfd
351 h.remain = remain
352 h.written, h.handled, h.err = h.original(dst, src, remain)
353 return h.written, h.handled, h.err
354 }
355 }
356
// uninstall restores *PollSpliceFile to the function saved by install.
func (h *spliceFileHook) uninstall() {
	*PollSpliceFile = h.original
}
360
361
// TestProcCopy copies /proc/self/cmdline to a temporary file with io.Copy
// and checks that the copy matches what ReadFile returns for the same path.
// The test is skipped when the /proc file cannot be read.
func TestProcCopy(t *testing.T) {
	t.Parallel()

	const cmdlineFile = "/proc/self/cmdline"

	want, err := ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}

	src, err := Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()

	dstPath := filepath.Join(t.TempDir(), "cmdline")
	dst, err := Create(dstPath)
	if err != nil {
		t.Fatal(err)
	}
	if _, err = io.Copy(dst, src); err != nil {
		t.Fatal(err)
	}
	// Close explicitly so that a failure to close is reported before the
	// output is read back.
	if err = dst.Close(); err != nil {
		t.Fatal(err)
	}

	got, err := ReadFile(dstPath)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(want, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, want)
	}
}
394
395 func TestGetPollFDAndNetwork(t *testing.T) {
396 t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
397 t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
398 }
399
400 func testGetPollFDAndNetwork(t *testing.T, proto string) {
401 _, server := createSocketPair(t, proto)
402 sc, ok := server.(syscall.Conn)
403 if !ok {
404 t.Fatalf("server Conn is not a syscall.Conn")
405 }
406 rc, err := sc.SyscallConn()
407 if err != nil {
408 t.Fatalf("server SyscallConn error: %v", err)
409 }
410 if err = rc.Control(func(fd uintptr) {
411 pfd, network := GetPollFDAndNetwork(server)
412 if pfd == nil {
413 t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
414 }
415 if string(network) != proto {
416 t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
417 }
418 if pfd.Sysfd != int(fd) {
419 t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
420 }
421 if !pfd.IsStream {
422 t.Fatalf("expected IsStream to be true")
423 }
424 if err = pfd.Init(proto, true); err == nil {
425 t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
426 }
427 }); err != nil {
428 t.Fatalf("server Control error: %v", err)
429 }
430 }
431
View as plain text