Browse Source

refactored to use Options.2DtmError

pull/328/head
yedf2 4 years ago
parent
commit
8f08358c7d
  1. 29
      dtmgrpc/workflow/imp.go
  2. 86
      dtmgrpc/workflow/utils.go
  3. 23
      dtmgrpc/workflow/workflow.go

29
dtmgrpc/workflow/imp.go

@ -36,10 +36,14 @@ func (wf *Workflow) loadProgresses() error {
if err == nil { if err == nil {
wf.progresses = map[string]*stepResult{} wf.progresses = map[string]*stepResult{}
for _, p := range progresses { for _, p := range progresses {
wf.progresses[p.BranchID+"-"+p.Op] = &stepResult{ sr := &stepResult{
Status: p.Status, Status: p.Status,
Data: p.BinData, Data: p.BinData,
} }
if sr.Status == dtmcli.StatusFailed {
sr.Error = fmt.Errorf("%s. %w", string(p.BinData), dtmcli.ErrFailure)
}
wf.progresses[p.BranchID+"-"+p.Op] = sr
} }
} }
return err return err
@ -71,6 +75,8 @@ func (w *workflowFactory) newWorkflow(name string, gid string, data []byte) *Wor
"data": data, "data": data,
}) })
wf.Context = context.WithValue(wf.Context, wfMeta{}, wf) wf.Context = context.WithValue(wf.Context, wfMeta{}, wf)
wf.Options.HTTPResp2DtmError = HTTPResp2DtmError
wf.Options.GRPCError2DtmError = dtmgrpc.GrpcError2DtmError
wf.initRestyClient() wf.initRestyClient()
return wf return wf
} }
@ -90,11 +96,7 @@ func (wf *Workflow) initRestyClient() {
old := wf.restyClient.GetClient().Transport old := wf.restyClient.GetClient().Transport
wf.restyClient.GetClient().Transport = newRoundTripper(old, wf) wf.restyClient.GetClient().Transport = newRoundTripper(old, wf)
wf.restyClient.OnAfterResponse(func(c *resty.Client, r *resty.Response) error { wf.restyClient.OnAfterResponse(func(c *resty.Client, r *resty.Response) error {
err := dtmimp.AfterResponse(c, r) return dtmimp.AfterResponse(c, r)
if err == nil && !wf.Options.DisalbeAutoError {
err = dtmimp.RespAsErrorCompatible(r) // check for dtm error
}
return err
}) })
} }
@ -119,10 +121,13 @@ func (wf *Workflow) process(handler WfFunc, data []byte) (err error) {
} }
func (wf *Workflow) saveResult(branchID string, op string, sr *stepResult) error { func (wf *Workflow) saveResult(branchID string, op string, sr *stepResult) error {
if sr.Status == "" { if sr.Status != "" {
return sr.Error err := wf.registerBranch(sr.Data, branchID, op, sr.Status)
if err != nil {
return err
}
} }
return wf.registerBranch(sr.Data, branchID, op, sr.Status) return sr.Error
} }
func (wf *Workflow) processPhase2(err error) error { func (wf *Workflow) processPhase2(err error) error {
@ -151,9 +156,9 @@ func (wf *Workflow) callPhase2(branchID string, fn WfPhase2Func) error {
if errors.Is(err, dtmcli.ErrFailure) { if errors.Is(err, dtmcli.ErrFailure) {
panic("should not return ErrFail in phase2") panic("should not return ErrFail in phase2")
} }
return stepResultFromLocal(nil, err) return wf.stepResultFromLocal(nil, err)
}) })
_, err := stepResultToLocal(r) _, err := wf.stepResultToLocal(r)
return err return err
} }
@ -187,7 +192,7 @@ func (wf *Workflow) recordedDoInner(fn func(bb *dtmcli.BranchBarrier) *stepResul
r = fn(bb) r = fn(bb)
err := wf.saveResult(branchID, wf.currentOp, r) err := wf.saveResult(branchID, wf.currentOp, r)
if err != nil { if err != nil {
r = stepResultFromLocal(nil, err) r = wf.stepResultFromLocal(nil, err)
} }
return r return r
} }

86
dtmgrpc/workflow/utils.go

@ -1,6 +1,7 @@
package workflow package workflow
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -10,8 +11,6 @@ import (
"github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli"
"github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/dtmimp"
"github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
) )
@ -58,21 +57,39 @@ func newJSONResponse(status int, result []byte) *http.Response {
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
wf := r.wf wf := r.wf
origin := func(bb *dtmcli.BranchBarrier) *stepResult {
resp, err := r.old.RoundTrip(req)
return wf.stepResultFromHTTP(resp, err)
}
var sr *stepResult
if wf.currentOp != dtmimp.OpAction { // in phase 2, do not save, because it is saved outer if wf.currentOp != dtmimp.OpAction { // in phase 2, do not save, because it is saved outer
return r.old.RoundTrip(req) sr = origin(nil)
} else {
sr = wf.recordedDo(origin)
} }
sr := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult { return wf.stepResultToHTTP(sr)
resp, err := r.old.RoundTrip(req)
return stepResultFromHTTP(resp, err)
})
return stepResultToHTTP(sr)
} }
func newRoundTripper(old http.RoundTripper, wf *Workflow) http.RoundTripper { func newRoundTripper(old http.RoundTripper, wf *Workflow) http.RoundTripper {
return &roundTripper{old: old, wf: wf} return &roundTripper{old: old, wf: wf}
} }
func stepResultFromLocal(data []byte, err error) *stepResult { // HTTPResp2DtmError check for dtm error and return it
func HTTPResp2DtmError(resp *http.Response) ([]byte, error) {
code := resp.StatusCode
data, err := ioutil.ReadAll(resp.Body)
resp.Body = ioutil.NopCloser(bytes.NewBuffer(data))
if code == http.StatusTooEarly {
return data, fmt.Errorf("%s. %w", string(data), dtmcli.ErrOngoing)
} else if code == http.StatusConflict {
return data, fmt.Errorf("%s. %w", string(data), dtmcli.ErrFailure)
} else if err == nil && code != http.StatusOK {
return data, errors.New(string(data))
}
return data, err
}
func (wf *Workflow) stepResultFromLocal(data []byte, err error) *stepResult {
return &stepResult{ return &stepResult{
Error: err, Error: err,
Status: wfErrorToStatus(err), Status: wfErrorToStatus(err),
@ -80,56 +97,41 @@ func stepResultFromLocal(data []byte, err error) *stepResult {
} }
} }
func stepResultToLocal(s *stepResult) ([]byte, error) { func (wf *Workflow) stepResultToLocal(sr *stepResult) ([]byte, error) {
if s.Error != nil { return sr.Data, sr.Error
return nil, s.Error
} else if s.Status == dtmcli.StatusFailed {
return nil, fmt.Errorf("%s. %w", string(s.Data), dtmcli.ErrFailure)
}
return s.Data, nil
} }
func stepResultFromGrpc(reply interface{}, err error) *stepResult { func (wf *Workflow) stepResultFromGrpc(reply interface{}, err error) *stepResult {
sr := &stepResult{} sr := &stepResult{Error: err}
st, ok := status.FromError(err)
if err == nil { if err == nil {
sr.Status = dtmcli.StatusSucceed sr.Error = wf.Options.GRPCError2DtmError(err)
sr.Data = dtmgimp.MustProtoMarshal(reply.(protoreflect.ProtoMessage)) sr.Status = wfErrorToStatus(sr.Error)
} else if ok && st.Code() == codes.Aborted { if sr.Error == nil {
sr.Status = dtmcli.StatusFailed sr.Data = dtmgimp.MustProtoMarshal(reply.(protoreflect.ProtoMessage))
sr.Data = []byte(st.Message()) } else if sr.Status == dtmcli.StatusFailed {
} else { sr.Data = []byte(sr.Error.Error())
sr.Error = err }
} }
return sr return sr
} }
func stepResultToGrpc(s *stepResult, reply interface{}) error { func (wf *Workflow) stepResultToGrpc(s *stepResult, reply interface{}) error {
if s.Error != nil { if s.Error == nil && s.Status == dtmcli.StatusSucceed {
return s.Error
} else if s.Status == dtmcli.StatusSucceed {
dtmgimp.MustProtoUnmarshal(s.Data, reply.(protoreflect.ProtoMessage)) dtmgimp.MustProtoUnmarshal(s.Data, reply.(protoreflect.ProtoMessage))
return nil
} }
return status.New(codes.Aborted, string(s.Data)).Err() return s.Error
} }
func stepResultFromHTTP(resp *http.Response, err error) *stepResult { func (wf *Workflow) stepResultFromHTTP(resp *http.Response, err error) *stepResult {
sr := &stepResult{Error: err} sr := &stepResult{Error: err}
if err == nil { if err == nil {
sr.Data, sr.Error = ioutil.ReadAll(resp.Body) sr.Data, sr.Error = wf.Options.HTTPResp2DtmError(resp)
if resp.StatusCode == http.StatusOK { sr.Status = wfErrorToStatus(sr.Error)
sr.Status = dtmcli.StatusSucceed
} else if resp.StatusCode == http.StatusConflict {
sr.Status = dtmcli.StatusFailed
} else {
sr.Error = errors.New(string(sr.Data))
}
} }
return sr return sr
} }
func stepResultToHTTP(s *stepResult) (*http.Response, error) { func (wf *Workflow) stepResultToHTTP(s *stepResult) (*http.Response, error) {
if s.Error != nil { if s.Error != nil {
return nil, s.Error return nil, s.Error
} }

23
dtmgrpc/workflow/workflow.go

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli"
@ -59,12 +60,16 @@ func ExecuteByQS(qs url.Values, body []byte) error {
// Options is for specifying workflow options // Options is for specifying workflow options
type Options struct { type Options struct {
// default false: Workflow's restyClient will convert http response to error if status code is not 200
// if this flag is set true, then Workflow's restyClient will keep the origin http response
DisalbeAutoError bool
// default false: fn registered by OnBranchRollback will not be called for FAILURE branch // Default: Code 409 => ErrFailure; Code 425 => ErrOngoing
// if this flag is set true, then fn will be called. the user should handle null rollback and dangling HTTPResp2DtmError func(*http.Response) ([]byte, error)
// Default: Code Aborted => ErrFailure; Code FailedPrecondition => ErrOngoing
GRPCError2DtmError func(error) error
// This Option specify whether a branch returning ErrFailure should be compensated on rollback.
// for most idempotent branches, no compensation is needed.
// But for a timeout request, the caller cannot know where the request is successful, so the compensation should be called
CompensateErrorBranch bool CompensateErrorBranch bool
} }
@ -147,9 +152,9 @@ func (wf *Workflow) OnBranchCommit(fn WfPhase2Func) *Workflow {
func (wf *Workflow) Do(fn func(bb *dtmcli.BranchBarrier) ([]byte, error)) ([]byte, error) { func (wf *Workflow) Do(fn func(bb *dtmcli.BranchBarrier) ([]byte, error)) ([]byte, error) {
res := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult { res := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult {
r, e := fn(bb) r, e := fn(bb)
return stepResultFromLocal(r, e) return wf.stepResultFromLocal(r, e)
}) })
return stepResultToLocal(res) return wf.stepResultToLocal(res)
} }
// DoXa will begin a local xa transaction // DoXa will begin a local xa transaction
@ -217,7 +222,7 @@ func Interceptor(ctx context.Context, method string, req, reply interface{}, cc
} }
sr := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult { sr := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult {
err := origin() err := origin()
return stepResultFromGrpc(reply, err) return wf.stepResultFromGrpc(reply, err)
}) })
return stepResultToGrpc(sr, reply) return wf.stepResultToGrpc(sr, reply)
} }

Loading…
Cancel
Save