PageRenderTime 42ms CodeModel.GetById 15ms RepoModel.GetById 0ms app.codeStats 0ms

/third_party/gofrontend/libgo/go/net/rpc/server_test.go

http://github.com/axw/llgo
Go | 683 lines | 568 code | 92 blank | 23 comment | 153 complexity | 3800e872b7a7a86a7ebe9a8594bcc13d MD5 | raw file
Possible License(s): BSD-3-Clause, MIT
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package rpc
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http/httptest"
  12. "runtime"
  13. "strings"
  14. "sync"
  15. "sync/atomic"
  16. "testing"
  17. "time"
  18. )
  19. var (
  20. newServer *Server
  21. serverAddr, newServerAddr string
  22. httpServerAddr string
  23. once, newOnce, httpOnce sync.Once
  24. )
  25. const (
  26. newHttpPath = "/foo"
  27. )
  28. type Args struct {
  29. A, B int
  30. }
  31. type Reply struct {
  32. C int
  33. }
  34. type Arith int
  35. // Some of Arith's methods have value args, some have pointer args. That's deliberate.
  36. func (t *Arith) Add(args Args, reply *Reply) error {
  37. reply.C = args.A + args.B
  38. return nil
  39. }
  40. func (t *Arith) Mul(args *Args, reply *Reply) error {
  41. reply.C = args.A * args.B
  42. return nil
  43. }
  44. func (t *Arith) Div(args Args, reply *Reply) error {
  45. if args.B == 0 {
  46. return errors.New("divide by zero")
  47. }
  48. reply.C = args.A / args.B
  49. return nil
  50. }
  51. func (t *Arith) String(args *Args, reply *string) error {
  52. *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
  53. return nil
  54. }
  55. func (t *Arith) Scan(args string, reply *Reply) (err error) {
  56. _, err = fmt.Sscan(args, &reply.C)
  57. return
  58. }
  59. func (t *Arith) Error(args *Args, reply *Reply) error {
  60. panic("ERROR")
  61. }
  62. func listenTCP() (net.Listener, string) {
  63. l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
  64. if e != nil {
  65. log.Fatalf("net.Listen tcp :0: %v", e)
  66. }
  67. return l, l.Addr().String()
  68. }
  69. func startServer() {
  70. Register(new(Arith))
  71. RegisterName("net.rpc.Arith", new(Arith))
  72. var l net.Listener
  73. l, serverAddr = listenTCP()
  74. log.Println("Test RPC server listening on", serverAddr)
  75. go Accept(l)
  76. HandleHTTP()
  77. httpOnce.Do(startHttpServer)
  78. }
  79. func startNewServer() {
  80. newServer = NewServer()
  81. newServer.Register(new(Arith))
  82. newServer.RegisterName("net.rpc.Arith", new(Arith))
  83. newServer.RegisterName("newServer.Arith", new(Arith))
  84. var l net.Listener
  85. l, newServerAddr = listenTCP()
  86. log.Println("NewServer test RPC server listening on", newServerAddr)
  87. go newServer.Accept(l)
  88. newServer.HandleHTTP(newHttpPath, "/bar")
  89. httpOnce.Do(startHttpServer)
  90. }
  91. func startHttpServer() {
  92. server := httptest.NewServer(nil)
  93. httpServerAddr = server.Listener.Addr().String()
  94. log.Println("Test HTTP RPC server listening on", httpServerAddr)
  95. }
  96. func TestRPC(t *testing.T) {
  97. once.Do(startServer)
  98. testRPC(t, serverAddr)
  99. newOnce.Do(startNewServer)
  100. testRPC(t, newServerAddr)
  101. testNewServerRPC(t, newServerAddr)
  102. }
  103. func testRPC(t *testing.T, addr string) {
  104. client, err := Dial("tcp", addr)
  105. if err != nil {
  106. t.Fatal("dialing", err)
  107. }
  108. defer client.Close()
  109. // Synchronous calls
  110. args := &Args{7, 8}
  111. reply := new(Reply)
  112. err = client.Call("Arith.Add", args, reply)
  113. if err != nil {
  114. t.Errorf("Add: expected no error but got string %q", err.Error())
  115. }
  116. if reply.C != args.A+args.B {
  117. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  118. }
  119. // Nonexistent method
  120. args = &Args{7, 0}
  121. reply = new(Reply)
  122. err = client.Call("Arith.BadOperation", args, reply)
  123. // expect an error
  124. if err == nil {
  125. t.Error("BadOperation: expected error")
  126. } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
  127. t.Errorf("BadOperation: expected can't find method error; got %q", err)
  128. }
  129. // Unknown service
  130. args = &Args{7, 8}
  131. reply = new(Reply)
  132. err = client.Call("Arith.Unknown", args, reply)
  133. if err == nil {
  134. t.Error("expected error calling unknown service")
  135. } else if strings.Index(err.Error(), "method") < 0 {
  136. t.Error("expected error about method; got", err)
  137. }
  138. // Out of order.
  139. args = &Args{7, 8}
  140. mulReply := new(Reply)
  141. mulCall := client.Go("Arith.Mul", args, mulReply, nil)
  142. addReply := new(Reply)
  143. addCall := client.Go("Arith.Add", args, addReply, nil)
  144. addCall = <-addCall.Done
  145. if addCall.Error != nil {
  146. t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
  147. }
  148. if addReply.C != args.A+args.B {
  149. t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
  150. }
  151. mulCall = <-mulCall.Done
  152. if mulCall.Error != nil {
  153. t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
  154. }
  155. if mulReply.C != args.A*args.B {
  156. t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
  157. }
  158. // Error test
  159. args = &Args{7, 0}
  160. reply = new(Reply)
  161. err = client.Call("Arith.Div", args, reply)
  162. // expect an error: zero divide
  163. if err == nil {
  164. t.Error("Div: expected error")
  165. } else if err.Error() != "divide by zero" {
  166. t.Error("Div: expected divide by zero error; got", err)
  167. }
  168. // Bad type.
  169. reply = new(Reply)
  170. err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
  171. if err == nil {
  172. t.Error("expected error calling Arith.Add with wrong arg type")
  173. } else if strings.Index(err.Error(), "type") < 0 {
  174. t.Error("expected error about type; got", err)
  175. }
  176. // Non-struct argument
  177. const Val = 12345
  178. str := fmt.Sprint(Val)
  179. reply = new(Reply)
  180. err = client.Call("Arith.Scan", &str, reply)
  181. if err != nil {
  182. t.Errorf("Scan: expected no error but got string %q", err.Error())
  183. } else if reply.C != Val {
  184. t.Errorf("Scan: expected %d got %d", Val, reply.C)
  185. }
  186. // Non-struct reply
  187. args = &Args{27, 35}
  188. str = ""
  189. err = client.Call("Arith.String", args, &str)
  190. if err != nil {
  191. t.Errorf("String: expected no error but got string %q", err.Error())
  192. }
  193. expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
  194. if str != expect {
  195. t.Errorf("String: expected %s got %s", expect, str)
  196. }
  197. args = &Args{7, 8}
  198. reply = new(Reply)
  199. err = client.Call("Arith.Mul", args, reply)
  200. if err != nil {
  201. t.Errorf("Mul: expected no error but got string %q", err.Error())
  202. }
  203. if reply.C != args.A*args.B {
  204. t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
  205. }
  206. // ServiceName contain "." character
  207. args = &Args{7, 8}
  208. reply = new(Reply)
  209. err = client.Call("net.rpc.Arith.Add", args, reply)
  210. if err != nil {
  211. t.Errorf("Add: expected no error but got string %q", err.Error())
  212. }
  213. if reply.C != args.A+args.B {
  214. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  215. }
  216. }
  217. func testNewServerRPC(t *testing.T, addr string) {
  218. client, err := Dial("tcp", addr)
  219. if err != nil {
  220. t.Fatal("dialing", err)
  221. }
  222. defer client.Close()
  223. // Synchronous calls
  224. args := &Args{7, 8}
  225. reply := new(Reply)
  226. err = client.Call("newServer.Arith.Add", args, reply)
  227. if err != nil {
  228. t.Errorf("Add: expected no error but got string %q", err.Error())
  229. }
  230. if reply.C != args.A+args.B {
  231. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  232. }
  233. }
  234. func TestHTTP(t *testing.T) {
  235. once.Do(startServer)
  236. testHTTPRPC(t, "")
  237. newOnce.Do(startNewServer)
  238. testHTTPRPC(t, newHttpPath)
  239. }
  240. func testHTTPRPC(t *testing.T, path string) {
  241. var client *Client
  242. var err error
  243. if path == "" {
  244. client, err = DialHTTP("tcp", httpServerAddr)
  245. } else {
  246. client, err = DialHTTPPath("tcp", httpServerAddr, path)
  247. }
  248. if err != nil {
  249. t.Fatal("dialing", err)
  250. }
  251. defer client.Close()
  252. // Synchronous calls
  253. args := &Args{7, 8}
  254. reply := new(Reply)
  255. err = client.Call("Arith.Add", args, reply)
  256. if err != nil {
  257. t.Errorf("Add: expected no error but got string %q", err.Error())
  258. }
  259. if reply.C != args.A+args.B {
  260. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  261. }
  262. }
  263. // CodecEmulator provides a client-like api and a ServerCodec interface.
  264. // Can be used to test ServeRequest.
  265. type CodecEmulator struct {
  266. server *Server
  267. serviceMethod string
  268. args *Args
  269. reply *Reply
  270. err error
  271. }
  272. func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
  273. codec.serviceMethod = serviceMethod
  274. codec.args = args
  275. codec.reply = reply
  276. codec.err = nil
  277. var serverError error
  278. if codec.server == nil {
  279. serverError = ServeRequest(codec)
  280. } else {
  281. serverError = codec.server.ServeRequest(codec)
  282. }
  283. if codec.err == nil && serverError != nil {
  284. codec.err = serverError
  285. }
  286. return codec.err
  287. }
  288. func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
  289. req.ServiceMethod = codec.serviceMethod
  290. req.Seq = 0
  291. return nil
  292. }
  293. func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
  294. if codec.args == nil {
  295. return io.ErrUnexpectedEOF
  296. }
  297. *(argv.(*Args)) = *codec.args
  298. return nil
  299. }
  300. func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
  301. if resp.Error != "" {
  302. codec.err = errors.New(resp.Error)
  303. } else {
  304. *codec.reply = *(reply.(*Reply))
  305. }
  306. return nil
  307. }
  308. func (codec *CodecEmulator) Close() error {
  309. return nil
  310. }
  311. func TestServeRequest(t *testing.T) {
  312. once.Do(startServer)
  313. testServeRequest(t, nil)
  314. newOnce.Do(startNewServer)
  315. testServeRequest(t, newServer)
  316. }
  317. func testServeRequest(t *testing.T, server *Server) {
  318. client := CodecEmulator{server: server}
  319. defer client.Close()
  320. args := &Args{7, 8}
  321. reply := new(Reply)
  322. err := client.Call("Arith.Add", args, reply)
  323. if err != nil {
  324. t.Errorf("Add: expected no error but got string %q", err.Error())
  325. }
  326. if reply.C != args.A+args.B {
  327. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  328. }
  329. err = client.Call("Arith.Add", nil, reply)
  330. if err == nil {
  331. t.Errorf("expected error calling Arith.Add with nil arg")
  332. }
  333. }
  334. type ReplyNotPointer int
  335. type ArgNotPublic int
  336. type ReplyNotPublic int
  337. type NeedsPtrType int
  338. type local struct{}
  339. func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
  340. return nil
  341. }
  342. func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
  343. return nil
  344. }
  345. func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
  346. return nil
  347. }
  348. func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
  349. return nil
  350. }
  351. // Check that registration handles lots of bad methods and a type with no suitable methods.
  352. func TestRegistrationError(t *testing.T) {
  353. err := Register(new(ReplyNotPointer))
  354. if err == nil {
  355. t.Error("expected error registering ReplyNotPointer")
  356. }
  357. err = Register(new(ArgNotPublic))
  358. if err == nil {
  359. t.Error("expected error registering ArgNotPublic")
  360. }
  361. err = Register(new(ReplyNotPublic))
  362. if err == nil {
  363. t.Error("expected error registering ReplyNotPublic")
  364. }
  365. err = Register(NeedsPtrType(0))
  366. if err == nil {
  367. t.Error("expected error registering NeedsPtrType")
  368. } else if !strings.Contains(err.Error(), "pointer") {
  369. t.Error("expected hint when registering NeedsPtrType")
  370. }
  371. }
  372. type WriteFailCodec int
  373. func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
  374. // the panic caused by this error used to not unlock a lock.
  375. return errors.New("fail")
  376. }
  377. func (WriteFailCodec) ReadResponseHeader(*Response) error {
  378. select {}
  379. }
  380. func (WriteFailCodec) ReadResponseBody(interface{}) error {
  381. select {}
  382. }
  383. func (WriteFailCodec) Close() error {
  384. return nil
  385. }
  386. func TestSendDeadlock(t *testing.T) {
  387. client := NewClientWithCodec(WriteFailCodec(0))
  388. defer client.Close()
  389. done := make(chan bool)
  390. go func() {
  391. testSendDeadlock(client)
  392. testSendDeadlock(client)
  393. done <- true
  394. }()
  395. select {
  396. case <-done:
  397. return
  398. case <-time.After(5 * time.Second):
  399. t.Fatal("deadlock")
  400. }
  401. }
  402. func testSendDeadlock(client *Client) {
  403. defer func() {
  404. recover()
  405. }()
  406. args := &Args{7, 8}
  407. reply := new(Reply)
  408. client.Call("Arith.Add", args, reply)
  409. }
  410. func dialDirect() (*Client, error) {
  411. return Dial("tcp", serverAddr)
  412. }
  413. func dialHTTP() (*Client, error) {
  414. return DialHTTP("tcp", httpServerAddr)
  415. }
  416. func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
  417. once.Do(startServer)
  418. client, err := dial()
  419. if err != nil {
  420. t.Fatal("error dialing", err)
  421. }
  422. defer client.Close()
  423. args := &Args{7, 8}
  424. reply := new(Reply)
  425. return testing.AllocsPerRun(100, func() {
  426. err := client.Call("Arith.Add", args, reply)
  427. if err != nil {
  428. t.Errorf("Add: expected no error but got string %q", err.Error())
  429. }
  430. if reply.C != args.A+args.B {
  431. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  432. }
  433. })
  434. }
  435. func TestCountMallocs(t *testing.T) {
  436. if testing.Short() {
  437. t.Skip("skipping malloc count in short mode")
  438. }
  439. if runtime.GOMAXPROCS(0) > 1 {
  440. t.Skip("skipping; GOMAXPROCS>1")
  441. }
  442. fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
  443. }
  444. func TestCountMallocsOverHTTP(t *testing.T) {
  445. if testing.Short() {
  446. t.Skip("skipping malloc count in short mode")
  447. }
  448. if runtime.GOMAXPROCS(0) > 1 {
  449. t.Skip("skipping; GOMAXPROCS>1")
  450. }
  451. fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
  452. }
  453. type writeCrasher struct {
  454. done chan bool
  455. }
  456. func (writeCrasher) Close() error {
  457. return nil
  458. }
  459. func (w *writeCrasher) Read(p []byte) (int, error) {
  460. <-w.done
  461. return 0, io.EOF
  462. }
  463. func (writeCrasher) Write(p []byte) (int, error) {
  464. return 0, errors.New("fake write failure")
  465. }
  466. func TestClientWriteError(t *testing.T) {
  467. w := &writeCrasher{done: make(chan bool)}
  468. c := NewClient(w)
  469. defer c.Close()
  470. res := false
  471. err := c.Call("foo", 1, &res)
  472. if err == nil {
  473. t.Fatal("expected error")
  474. }
  475. if err.Error() != "fake write failure" {
  476. t.Error("unexpected value of error:", err)
  477. }
  478. w.done <- true
  479. }
  480. func TestTCPClose(t *testing.T) {
  481. once.Do(startServer)
  482. client, err := dialHTTP()
  483. if err != nil {
  484. t.Fatalf("dialing: %v", err)
  485. }
  486. defer client.Close()
  487. args := Args{17, 8}
  488. var reply Reply
  489. err = client.Call("Arith.Mul", args, &reply)
  490. if err != nil {
  491. t.Fatal("arith error:", err)
  492. }
  493. t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
  494. if reply.C != args.A*args.B {
  495. t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
  496. }
  497. }
  498. func TestErrorAfterClientClose(t *testing.T) {
  499. once.Do(startServer)
  500. client, err := dialHTTP()
  501. if err != nil {
  502. t.Fatalf("dialing: %v", err)
  503. }
  504. err = client.Close()
  505. if err != nil {
  506. t.Fatal("close error:", err)
  507. }
  508. err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
  509. if err != ErrShutdown {
  510. t.Errorf("Forever: expected ErrShutdown got %v", err)
  511. }
  512. }
  513. func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
  514. once.Do(startServer)
  515. client, err := dial()
  516. if err != nil {
  517. b.Fatal("error dialing:", err)
  518. }
  519. defer client.Close()
  520. // Synchronous calls
  521. args := &Args{7, 8}
  522. b.ResetTimer()
  523. b.RunParallel(func(pb *testing.PB) {
  524. reply := new(Reply)
  525. for pb.Next() {
  526. err := client.Call("Arith.Add", args, reply)
  527. if err != nil {
  528. b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
  529. }
  530. if reply.C != args.A+args.B {
  531. b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
  532. }
  533. }
  534. })
  535. }
  536. func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
  537. const MaxConcurrentCalls = 100
  538. once.Do(startServer)
  539. client, err := dial()
  540. if err != nil {
  541. b.Fatal("error dialing:", err)
  542. }
  543. defer client.Close()
  544. // Asynchronous calls
  545. args := &Args{7, 8}
  546. procs := 4 * runtime.GOMAXPROCS(-1)
  547. send := int32(b.N)
  548. recv := int32(b.N)
  549. var wg sync.WaitGroup
  550. wg.Add(procs)
  551. gate := make(chan bool, MaxConcurrentCalls)
  552. res := make(chan *Call, MaxConcurrentCalls)
  553. b.ResetTimer()
  554. for p := 0; p < procs; p++ {
  555. go func() {
  556. for atomic.AddInt32(&send, -1) >= 0 {
  557. gate <- true
  558. reply := new(Reply)
  559. client.Go("Arith.Add", args, reply, res)
  560. }
  561. }()
  562. go func() {
  563. for call := range res {
  564. A := call.Args.(*Args).A
  565. B := call.Args.(*Args).B
  566. C := call.Reply.(*Reply).C
  567. if A+B != C {
  568. b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
  569. }
  570. <-gate
  571. if atomic.AddInt32(&recv, -1) == 0 {
  572. close(res)
  573. }
  574. }
  575. wg.Done()
  576. }()
  577. }
  578. wg.Wait()
  579. }
  580. func BenchmarkEndToEnd(b *testing.B) {
  581. benchmarkEndToEnd(dialDirect, b)
  582. }
  583. func BenchmarkEndToEndHTTP(b *testing.B) {
  584. benchmarkEndToEnd(dialHTTP, b)
  585. }
  586. func BenchmarkEndToEndAsync(b *testing.B) {
  587. benchmarkEndToEndAsync(dialDirect, b)
  588. }
  589. func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
  590. benchmarkEndToEndAsync(dialHTTP, b)
  591. }