/vendor/google.golang.org/grpc/grpclb/grpclb_test.go

https://github.com/Bytom/bytom · Go · 901 lines · 782 code · 64 blank · 55 comment · 193 complexity · 499294f5063a3c9fa236e96b09eb6410 MD5 · raw file

  1. /*
  2. *
  3. * Copyright 2016 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. //go:generate protoc --go_out=plugins=:$GOPATH grpc_lb_v1/messages/messages.proto
  19. //go:generate protoc --go_out=plugins=grpc:$GOPATH grpc_lb_v1/service/service.proto
  20. // Package grpclb_test is currently used only for grpclb testing.
  21. package grpclb_test
  22. import (
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "strings"
  28. "sync"
  29. "testing"
  30. "time"
  31. "github.com/golang/protobuf/proto"
  32. "golang.org/x/net/context"
  33. "google.golang.org/grpc"
  34. "google.golang.org/grpc/codes"
  35. "google.golang.org/grpc/credentials"
  36. lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
  37. lbspb "google.golang.org/grpc/grpclb/grpc_lb_v1/service"
  38. _ "google.golang.org/grpc/grpclog/glogger"
  39. "google.golang.org/grpc/metadata"
  40. "google.golang.org/grpc/naming"
  41. testpb "google.golang.org/grpc/test/grpc_testing"
  42. "google.golang.org/grpc/test/leakcheck"
  43. )
  44. var (
  45. lbsn = "bar.com"
  46. besn = "foo.com"
  47. lbToken = "iamatoken"
  48. // Resolver replaces localhost with fakeName in Next().
  49. // Dialer replaces fakeName with localhost when dialing.
  50. // This will test that custom dialer is passed from Dial to grpclb.
  51. fakeName = "fake.Name"
  52. )
  53. type testWatcher struct {
  54. // the channel to receives name resolution updates
  55. update chan *naming.Update
  56. // the side channel to get to know how many updates in a batch
  57. side chan int
  58. // the channel to notifiy update injector that the update reading is done
  59. readDone chan int
  60. }
  61. func (w *testWatcher) Next() (updates []*naming.Update, err error) {
  62. n, ok := <-w.side
  63. if !ok {
  64. return nil, fmt.Errorf("w.side is closed")
  65. }
  66. for i := 0; i < n; i++ {
  67. u, ok := <-w.update
  68. if !ok {
  69. break
  70. }
  71. if u != nil {
  72. // Resolver replaces localhost with fakeName in Next().
  73. // Custom dialer will replace fakeName with localhost when dialing.
  74. u.Addr = strings.Replace(u.Addr, "localhost", fakeName, 1)
  75. updates = append(updates, u)
  76. }
  77. }
  78. w.readDone <- 0
  79. return
  80. }
  81. func (w *testWatcher) Close() {
  82. close(w.side)
  83. }
  84. // Inject naming resolution updates to the testWatcher.
  85. func (w *testWatcher) inject(updates []*naming.Update) {
  86. w.side <- len(updates)
  87. for _, u := range updates {
  88. w.update <- u
  89. }
  90. <-w.readDone
  91. }
  92. type testNameResolver struct {
  93. w *testWatcher
  94. addrs []string
  95. }
  96. func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
  97. r.w = &testWatcher{
  98. update: make(chan *naming.Update, len(r.addrs)),
  99. side: make(chan int, 1),
  100. readDone: make(chan int),
  101. }
  102. r.w.side <- len(r.addrs)
  103. for _, addr := range r.addrs {
  104. r.w.update <- &naming.Update{
  105. Op: naming.Add,
  106. Addr: addr,
  107. Metadata: &naming.AddrMetadataGRPCLB{
  108. AddrType: naming.GRPCLB,
  109. ServerName: lbsn,
  110. },
  111. }
  112. }
  113. go func() {
  114. <-r.w.readDone
  115. }()
  116. return r.w, nil
  117. }
  118. func (r *testNameResolver) inject(updates []*naming.Update) {
  119. if r.w != nil {
  120. r.w.inject(updates)
  121. }
  122. }
  123. type serverNameCheckCreds struct {
  124. mu sync.Mutex
  125. sn string
  126. expected string
  127. }
  128. func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  129. if _, err := io.WriteString(rawConn, c.sn); err != nil {
  130. fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
  131. return nil, nil, err
  132. }
  133. return rawConn, nil, nil
  134. }
  135. func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  136. c.mu.Lock()
  137. defer c.mu.Unlock()
  138. b := make([]byte, len(c.expected))
  139. errCh := make(chan error, 1)
  140. go func() {
  141. _, err := rawConn.Read(b)
  142. errCh <- err
  143. }()
  144. select {
  145. case err := <-errCh:
  146. if err != nil {
  147. fmt.Printf("Failed to read the server name from the server %v", err)
  148. return nil, nil, err
  149. }
  150. case <-ctx.Done():
  151. return nil, nil, ctx.Err()
  152. }
  153. if c.expected != string(b) {
  154. fmt.Printf("Read the server name %s want %s", string(b), c.expected)
  155. return nil, nil, errors.New("received unexpected server name")
  156. }
  157. return rawConn, nil, nil
  158. }
  159. func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
  160. c.mu.Lock()
  161. defer c.mu.Unlock()
  162. return credentials.ProtocolInfo{}
  163. }
  164. func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
  165. c.mu.Lock()
  166. defer c.mu.Unlock()
  167. return &serverNameCheckCreds{
  168. expected: c.expected,
  169. }
  170. }
  171. func (c *serverNameCheckCreds) OverrideServerName(s string) error {
  172. c.mu.Lock()
  173. defer c.mu.Unlock()
  174. c.expected = s
  175. return nil
  176. }
  177. // fakeNameDialer replaces fakeName with localhost when dialing.
  178. // This will test that custom dialer is passed from Dial to grpclb.
  179. func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
  180. addr = strings.Replace(addr, fakeName, "localhost", 1)
  181. return net.DialTimeout("tcp", addr, timeout)
  182. }
  183. type remoteBalancer struct {
  184. sls []*lbmpb.ServerList
  185. intervals []time.Duration
  186. statsDura time.Duration
  187. done chan struct{}
  188. mu sync.Mutex
  189. stats lbmpb.ClientStats
  190. }
  191. func newRemoteBalancer(sls []*lbmpb.ServerList, intervals []time.Duration) *remoteBalancer {
  192. return &remoteBalancer{
  193. sls: sls,
  194. intervals: intervals,
  195. done: make(chan struct{}),
  196. }
  197. }
  198. func (b *remoteBalancer) stop() {
  199. close(b.done)
  200. }
  201. func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer) error {
  202. req, err := stream.Recv()
  203. if err != nil {
  204. return err
  205. }
  206. initReq := req.GetInitialRequest()
  207. if initReq.Name != besn {
  208. return grpc.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
  209. }
  210. resp := &lbmpb.LoadBalanceResponse{
  211. LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_InitialResponse{
  212. InitialResponse: &lbmpb.InitialLoadBalanceResponse{
  213. ClientStatsReportInterval: &lbmpb.Duration{
  214. Seconds: int64(b.statsDura.Seconds()),
  215. Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
  216. },
  217. },
  218. },
  219. }
  220. if err := stream.Send(resp); err != nil {
  221. return err
  222. }
  223. go func() {
  224. for {
  225. var (
  226. req *lbmpb.LoadBalanceRequest
  227. err error
  228. )
  229. if req, err = stream.Recv(); err != nil {
  230. return
  231. }
  232. b.mu.Lock()
  233. b.stats.NumCallsStarted += req.GetClientStats().NumCallsStarted
  234. b.stats.NumCallsFinished += req.GetClientStats().NumCallsFinished
  235. b.stats.NumCallsFinishedWithDropForRateLimiting += req.GetClientStats().NumCallsFinishedWithDropForRateLimiting
  236. b.stats.NumCallsFinishedWithDropForLoadBalancing += req.GetClientStats().NumCallsFinishedWithDropForLoadBalancing
  237. b.stats.NumCallsFinishedWithClientFailedToSend += req.GetClientStats().NumCallsFinishedWithClientFailedToSend
  238. b.stats.NumCallsFinishedKnownReceived += req.GetClientStats().NumCallsFinishedKnownReceived
  239. b.mu.Unlock()
  240. }
  241. }()
  242. for k, v := range b.sls {
  243. time.Sleep(b.intervals[k])
  244. resp = &lbmpb.LoadBalanceResponse{
  245. LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_ServerList{
  246. ServerList: v,
  247. },
  248. }
  249. if err := stream.Send(resp); err != nil {
  250. return err
  251. }
  252. }
  253. <-b.done
  254. return nil
  255. }
  256. type testServer struct {
  257. testpb.TestServiceServer
  258. addr string
  259. }
  260. const testmdkey = "testmd"
  261. func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
  262. md, ok := metadata.FromIncomingContext(ctx)
  263. if !ok {
  264. return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
  265. }
  266. if md == nil || md["lb-token"][0] != lbToken {
  267. return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
  268. }
  269. grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
  270. return &testpb.Empty{}, nil
  271. }
  272. func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
  273. return nil
  274. }
  275. func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
  276. for _, l := range lis {
  277. creds := &serverNameCheckCreds{
  278. sn: sn,
  279. }
  280. s := grpc.NewServer(grpc.Creds(creds))
  281. testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()})
  282. servers = append(servers, s)
  283. go func(s *grpc.Server, l net.Listener) {
  284. s.Serve(l)
  285. }(s, l)
  286. }
  287. return
  288. }
  289. func stopBackends(servers []*grpc.Server) {
  290. for _, s := range servers {
  291. s.Stop()
  292. }
  293. }
  294. type testServers struct {
  295. lbAddr string
  296. ls *remoteBalancer
  297. lb *grpc.Server
  298. beIPs []net.IP
  299. bePorts []int
  300. }
  301. func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
  302. var (
  303. beListeners []net.Listener
  304. ls *remoteBalancer
  305. lb *grpc.Server
  306. beIPs []net.IP
  307. bePorts []int
  308. )
  309. for i := 0; i < numberOfBackends; i++ {
  310. // Start a backend.
  311. beLis, e := net.Listen("tcp", "localhost:0")
  312. if e != nil {
  313. err = fmt.Errorf("Failed to listen %v", err)
  314. return
  315. }
  316. beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
  317. bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port)
  318. beListeners = append(beListeners, beLis)
  319. }
  320. backends := startBackends(besn, beListeners...)
  321. // Start a load balancer.
  322. lbLis, err := net.Listen("tcp", "localhost:0")
  323. if err != nil {
  324. err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
  325. return
  326. }
  327. lbCreds := &serverNameCheckCreds{
  328. sn: lbsn,
  329. }
  330. lb = grpc.NewServer(grpc.Creds(lbCreds))
  331. if err != nil {
  332. err = fmt.Errorf("Failed to generate the port number %v", err)
  333. return
  334. }
  335. ls = newRemoteBalancer(nil, nil)
  336. lbspb.RegisterLoadBalancerServer(lb, ls)
  337. go func() {
  338. lb.Serve(lbLis)
  339. }()
  340. tss = &testServers{
  341. lbAddr: lbLis.Addr().String(),
  342. ls: ls,
  343. lb: lb,
  344. beIPs: beIPs,
  345. bePorts: bePorts,
  346. }
  347. cleanup = func() {
  348. defer stopBackends(backends)
  349. defer func() {
  350. ls.stop()
  351. lb.Stop()
  352. }()
  353. }
  354. return
  355. }
  356. func TestGRPCLB(t *testing.T) {
  357. defer leakcheck.Check(t)
  358. tss, cleanup, err := newLoadBalancer(1)
  359. if err != nil {
  360. t.Fatalf("failed to create new load balancer: %v", err)
  361. }
  362. defer cleanup()
  363. be := &lbmpb.Server{
  364. IpAddress: tss.beIPs[0],
  365. Port: int32(tss.bePorts[0]),
  366. LoadBalanceToken: lbToken,
  367. }
  368. var bes []*lbmpb.Server
  369. bes = append(bes, be)
  370. sl := &lbmpb.ServerList{
  371. Servers: bes,
  372. }
  373. tss.ls.sls = []*lbmpb.ServerList{sl}
  374. tss.ls.intervals = []time.Duration{0}
  375. creds := serverNameCheckCreds{
  376. expected: besn,
  377. }
  378. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  379. defer cancel()
  380. cc, err := grpc.DialContext(ctx, besn,
  381. grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
  382. grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
  383. if err != nil {
  384. t.Fatalf("Failed to dial to the backend %v", err)
  385. }
  386. defer cc.Close()
  387. testC := testpb.NewTestServiceClient(cc)
  388. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
  389. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  390. }
  391. }
  392. func TestDropRequest(t *testing.T) {
  393. defer leakcheck.Check(t)
  394. tss, cleanup, err := newLoadBalancer(2)
  395. if err != nil {
  396. t.Fatalf("failed to create new load balancer: %v", err)
  397. }
  398. defer cleanup()
  399. tss.ls.sls = []*lbmpb.ServerList{{
  400. Servers: []*lbmpb.Server{{
  401. IpAddress: tss.beIPs[0],
  402. Port: int32(tss.bePorts[0]),
  403. LoadBalanceToken: lbToken,
  404. DropForLoadBalancing: true,
  405. }, {
  406. IpAddress: tss.beIPs[1],
  407. Port: int32(tss.bePorts[1]),
  408. LoadBalanceToken: lbToken,
  409. DropForLoadBalancing: false,
  410. }},
  411. }}
  412. tss.ls.intervals = []time.Duration{0}
  413. creds := serverNameCheckCreds{
  414. expected: besn,
  415. }
  416. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  417. defer cancel()
  418. cc, err := grpc.DialContext(ctx, besn,
  419. grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
  420. grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
  421. if err != nil {
  422. t.Fatalf("Failed to dial to the backend %v", err)
  423. }
  424. defer cc.Close()
  425. testC := testpb.NewTestServiceClient(cc)
  426. // Wait until the first connection is up.
  427. // The first one has Drop set to true, error should contain "drop requests".
  428. for {
  429. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  430. if strings.Contains(err.Error(), "drops requests") {
  431. break
  432. }
  433. }
  434. }
  435. // The 1st, non-fail-fast RPC should succeed. This ensures both server
  436. // connections are made, because the first one has DropForLoadBalancing set to true.
  437. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
  438. t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
  439. }
  440. for i := 0; i < 3; i++ {
  441. // Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing
  442. // set to true.
  443. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
  444. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
  445. }
  446. // Even fail-fast RPCs should succeed since they choose the
  447. // non-drop-request backend according to the round robin policy.
  448. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  449. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  450. }
  451. }
  452. }
  453. func TestDropRequestFailedNonFailFast(t *testing.T) {
  454. defer leakcheck.Check(t)
  455. tss, cleanup, err := newLoadBalancer(1)
  456. if err != nil {
  457. t.Fatalf("failed to create new load balancer: %v", err)
  458. }
  459. defer cleanup()
  460. be := &lbmpb.Server{
  461. IpAddress: tss.beIPs[0],
  462. Port: int32(tss.bePorts[0]),
  463. LoadBalanceToken: lbToken,
  464. DropForLoadBalancing: true,
  465. }
  466. var bes []*lbmpb.Server
  467. bes = append(bes, be)
  468. sl := &lbmpb.ServerList{
  469. Servers: bes,
  470. }
  471. tss.ls.sls = []*lbmpb.ServerList{sl}
  472. tss.ls.intervals = []time.Duration{0}
  473. creds := serverNameCheckCreds{
  474. expected: besn,
  475. }
  476. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  477. defer cancel()
  478. cc, err := grpc.DialContext(ctx, besn,
  479. grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
  480. grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
  481. if err != nil {
  482. t.Fatalf("Failed to dial to the backend %v", err)
  483. }
  484. defer cc.Close()
  485. testC := testpb.NewTestServiceClient(cc)
  486. ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond)
  487. defer cancel()
  488. if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
  489. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded)
  490. }
  491. }
  492. // When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
  493. func TestBalancerDisconnects(t *testing.T) {
  494. defer leakcheck.Check(t)
  495. var (
  496. lbAddrs []string
  497. lbs []*grpc.Server
  498. )
  499. for i := 0; i < 3; i++ {
  500. tss, cleanup, err := newLoadBalancer(1)
  501. if err != nil {
  502. t.Fatalf("failed to create new load balancer: %v", err)
  503. }
  504. defer cleanup()
  505. be := &lbmpb.Server{
  506. IpAddress: tss.beIPs[0],
  507. Port: int32(tss.bePorts[0]),
  508. LoadBalanceToken: lbToken,
  509. }
  510. var bes []*lbmpb.Server
  511. bes = append(bes, be)
  512. sl := &lbmpb.ServerList{
  513. Servers: bes,
  514. }
  515. tss.ls.sls = []*lbmpb.ServerList{sl}
  516. tss.ls.intervals = []time.Duration{0}
  517. lbAddrs = append(lbAddrs, tss.lbAddr)
  518. lbs = append(lbs, tss.lb)
  519. }
  520. creds := serverNameCheckCreds{
  521. expected: besn,
  522. }
  523. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  524. defer cancel()
  525. resolver := &testNameResolver{
  526. addrs: lbAddrs[:2],
  527. }
  528. cc, err := grpc.DialContext(ctx, besn,
  529. grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)),
  530. grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
  531. if err != nil {
  532. t.Fatalf("Failed to dial to the backend %v", err)
  533. }
  534. defer cc.Close()
  535. testC := testpb.NewTestServiceClient(cc)
  536. var previousTrailer string
  537. trailer := metadata.MD{}
  538. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
  539. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  540. } else {
  541. previousTrailer = trailer[testmdkey][0]
  542. }
  543. // The initial resolver update contains lbs[0] and lbs[1].
  544. // When lbs[0] is stopped, lbs[1] should be used.
  545. lbs[0].Stop()
  546. for {
  547. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
  548. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  549. } else if trailer[testmdkey][0] != previousTrailer {
  550. // A new backend server should receive the request.
  551. // The trailer contains the backend address, so the trailer should be different from the previous one.
  552. previousTrailer = trailer[testmdkey][0]
  553. break
  554. }
  555. time.Sleep(100 * time.Millisecond)
  556. }
  557. // Inject a update to add lbs[2] to resolved addresses.
  558. resolver.inject([]*naming.Update{
  559. {Op: naming.Add,
  560. Addr: lbAddrs[2],
  561. Metadata: &naming.AddrMetadataGRPCLB{
  562. AddrType: naming.GRPCLB,
  563. ServerName: lbsn,
  564. },
  565. },
  566. })
  567. // Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used.
  568. lbs[1].Stop()
  569. for {
  570. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
  571. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  572. } else if trailer[testmdkey][0] != previousTrailer {
  573. // A new backend server should receive the request.
  574. // The trailer contains the backend address, so the trailer should be different from the previous one.
  575. break
  576. }
  577. time.Sleep(100 * time.Millisecond)
  578. }
  579. }
  580. type failPreRPCCred struct{}
  581. func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
  582. if strings.Contains(uri[0], "failtosend") {
  583. return nil, fmt.Errorf("rpc should fail to send")
  584. }
  585. return nil, nil
  586. }
  587. func (failPreRPCCred) RequireTransportSecurity() bool {
  588. return false
  589. }
  590. func checkStats(stats *lbmpb.ClientStats, expected *lbmpb.ClientStats) error {
  591. if !proto.Equal(stats, expected) {
  592. return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
  593. }
  594. return nil
  595. }
  596. func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats {
  597. tss, cleanup, err := newLoadBalancer(3)
  598. if err != nil {
  599. t.Fatalf("failed to create new load balancer: %v", err)
  600. }
  601. defer cleanup()
  602. tss.ls.sls = []*lbmpb.ServerList{{
  603. Servers: []*lbmpb.Server{{
  604. IpAddress: tss.beIPs[2],
  605. Port: int32(tss.bePorts[2]),
  606. LoadBalanceToken: lbToken,
  607. DropForLoadBalancing: dropForLoadBalancing,
  608. DropForRateLimiting: dropForRateLimiting,
  609. }},
  610. }}
  611. tss.ls.intervals = []time.Duration{0}
  612. tss.ls.statsDura = 100 * time.Millisecond
  613. creds := serverNameCheckCreds{expected: besn}
  614. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  615. defer cancel()
  616. cc, err := grpc.DialContext(ctx, besn,
  617. grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
  618. grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}),
  619. grpc.WithBlock(), grpc.WithDialer(fakeNameDialer))
  620. if err != nil {
  621. t.Fatalf("Failed to dial to the backend %v", err)
  622. }
  623. defer cc.Close()
  624. runRPCs(cc)
  625. time.Sleep(1 * time.Second)
  626. tss.ls.mu.Lock()
  627. stats := tss.ls.stats
  628. tss.ls.mu.Unlock()
  629. return stats
  630. }
  631. const countRPC = 40
  632. func TestGRPCLBStatsUnarySuccess(t *testing.T) {
  633. defer leakcheck.Check(t)
  634. stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
  635. testC := testpb.NewTestServiceClient(cc)
  636. // The first non-failfast RPC succeeds, all connections are up.
  637. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
  638. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  639. }
  640. for i := 0; i < countRPC-1; i++ {
  641. testC.EmptyCall(context.Background(), &testpb.Empty{})
  642. }
  643. })
  644. if err := checkStats(&stats, &lbmpb.ClientStats{
  645. NumCallsStarted: int64(countRPC),
  646. NumCallsFinished: int64(countRPC),
  647. NumCallsFinishedKnownReceived: int64(countRPC),
  648. }); err != nil {
  649. t.Fatal(err)
  650. }
  651. }
  652. func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) {
  653. defer leakcheck.Check(t)
  654. c := 0
  655. stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
  656. testC := testpb.NewTestServiceClient(cc)
  657. for {
  658. c++
  659. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  660. if strings.Contains(err.Error(), "drops requests") {
  661. break
  662. }
  663. }
  664. }
  665. for i := 0; i < countRPC; i++ {
  666. testC.EmptyCall(context.Background(), &testpb.Empty{})
  667. }
  668. })
  669. if err := checkStats(&stats, &lbmpb.ClientStats{
  670. NumCallsStarted: int64(countRPC + c),
  671. NumCallsFinished: int64(countRPC + c),
  672. NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
  673. NumCallsFinishedWithClientFailedToSend: int64(c - 1),
  674. }); err != nil {
  675. t.Fatal(err)
  676. }
  677. }
  678. func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) {
  679. defer leakcheck.Check(t)
  680. c := 0
  681. stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
  682. testC := testpb.NewTestServiceClient(cc)
  683. for {
  684. c++
  685. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  686. if strings.Contains(err.Error(), "drops requests") {
  687. break
  688. }
  689. }
  690. }
  691. for i := 0; i < countRPC; i++ {
  692. testC.EmptyCall(context.Background(), &testpb.Empty{})
  693. }
  694. })
  695. if err := checkStats(&stats, &lbmpb.ClientStats{
  696. NumCallsStarted: int64(countRPC + c),
  697. NumCallsFinished: int64(countRPC + c),
  698. NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
  699. NumCallsFinishedWithClientFailedToSend: int64(c - 1),
  700. }); err != nil {
  701. t.Fatal(err)
  702. }
  703. }
  704. func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
  705. defer leakcheck.Check(t)
  706. stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
  707. testC := testpb.NewTestServiceClient(cc)
  708. // The first non-failfast RPC succeeds, all connections are up.
  709. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
  710. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  711. }
  712. for i := 0; i < countRPC-1; i++ {
  713. grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc)
  714. }
  715. })
  716. if err := checkStats(&stats, &lbmpb.ClientStats{
  717. NumCallsStarted: int64(countRPC),
  718. NumCallsFinished: int64(countRPC),
  719. NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
  720. NumCallsFinishedKnownReceived: 1,
  721. }); err != nil {
  722. t.Fatal(err)
  723. }
  724. }
  725. func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
  726. defer leakcheck.Check(t)
  727. stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
  728. testC := testpb.NewTestServiceClient(cc)
  729. // The first non-failfast RPC succeeds, all connections are up.
  730. stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
  731. if err != nil {
  732. t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
  733. }
  734. for {
  735. if _, err = stream.Recv(); err == io.EOF {
  736. break
  737. }
  738. }
  739. for i := 0; i < countRPC-1; i++ {
  740. stream, err = testC.FullDuplexCall(context.Background())
  741. if err == nil {
  742. // Wait for stream to end if err is nil.
  743. for {
  744. if _, err = stream.Recv(); err == io.EOF {
  745. break
  746. }
  747. }
  748. }
  749. }
  750. })
  751. if err := checkStats(&stats, &lbmpb.ClientStats{
  752. NumCallsStarted: int64(countRPC),
  753. NumCallsFinished: int64(countRPC),
  754. NumCallsFinishedKnownReceived: int64(countRPC),
  755. }); err != nil {
  756. t.Fatal(err)
  757. }
  758. }
  759. func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) {
  760. defer leakcheck.Check(t)
  761. c := 0
  762. stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
  763. testC := testpb.NewTestServiceClient(cc)
  764. for {
  765. c++
  766. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  767. if strings.Contains(err.Error(), "drops requests") {
  768. break
  769. }
  770. }
  771. }
  772. for i := 0; i < countRPC; i++ {
  773. testC.FullDuplexCall(context.Background())
  774. }
  775. })
  776. if err := checkStats(&stats, &lbmpb.ClientStats{
  777. NumCallsStarted: int64(countRPC + c),
  778. NumCallsFinished: int64(countRPC + c),
  779. NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
  780. NumCallsFinishedWithClientFailedToSend: int64(c - 1),
  781. }); err != nil {
  782. t.Fatal(err)
  783. }
  784. }
  785. func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) {
  786. defer leakcheck.Check(t)
  787. c := 0
  788. stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
  789. testC := testpb.NewTestServiceClient(cc)
  790. for {
  791. c++
  792. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  793. if strings.Contains(err.Error(), "drops requests") {
  794. break
  795. }
  796. }
  797. }
  798. for i := 0; i < countRPC; i++ {
  799. testC.FullDuplexCall(context.Background())
  800. }
  801. })
  802. if err := checkStats(&stats, &lbmpb.ClientStats{
  803. NumCallsStarted: int64(countRPC + c),
  804. NumCallsFinished: int64(countRPC + c),
  805. NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
  806. NumCallsFinishedWithClientFailedToSend: int64(c - 1),
  807. }); err != nil {
  808. t.Fatal(err)
  809. }
  810. }
  811. func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
  812. defer leakcheck.Check(t)
  813. stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
  814. testC := testpb.NewTestServiceClient(cc)
  815. // The first non-failfast RPC succeeds, all connections are up.
  816. stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
  817. if err != nil {
  818. t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
  819. }
  820. for {
  821. if _, err = stream.Recv(); err == io.EOF {
  822. break
  823. }
  824. }
  825. for i := 0; i < countRPC-1; i++ {
  826. grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend")
  827. }
  828. })
  829. if err := checkStats(&stats, &lbmpb.ClientStats{
  830. NumCallsStarted: int64(countRPC),
  831. NumCallsFinished: int64(countRPC),
  832. NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
  833. NumCallsFinishedKnownReceived: 1,
  834. }); err != nil {
  835. t.Fatal(err)
  836. }
  837. }