diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index e029b1a..814938d 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "strconv" "strings" "sync" "time" @@ -235,30 +236,78 @@ func (c *cache) Increment(key string, value int64) (int64, error) { c.Lock() defer c.Unlock() - var val int64 + var val interface{} if _, err := c.Get(key, &val); err != nil && err != ErrNotFound { return 0, err } - val += value - if err := c.Set(key, val, time.Time{}); err != nil { - return val, err + + var counter int64 + + switch val.(type) { + case string: + // try to convert to number + a, err := strconv.ParseInt(val.(string), 10, 64) + if err != nil { + return 0, err + } + counter = a + case int64: + counter = val.(int64) + case int32: + counter = int64(val.(int32)) + case int: + counter = int64(val.(int)) + case nil: + counter = 0 + default: + return 0, errors.New("value is not an integer") } - return val, nil + + counter += value + + if err := c.Set(key, fmt.Sprintf("%v",counter), time.Time{}); err != nil { + return counter, err + } + return counter, nil } func (c *cache) Decrement(key string, value int64) (int64, error) { c.Lock() defer c.Unlock() - var val int64 + var val interface{} if _, err := c.Get(key, &val); err != nil && err != ErrNotFound { return 0, err } - val -= value - if err := c.Set(key, val, time.Time{}); err != nil { - return val, err + + var counter int64 + + switch val.(type) { + case string: + // try to convert to number + a, err := strconv.ParseInt(val.(string), 10, 64) + if err != nil { + return 0, err + } + counter = a + case int64: + counter = val.(int64) + case int32: + counter = int64(val.(int32)) + case int: + counter = int64(val.(int)) + case nil: + counter = 0 + default: + return 0, errors.New("value is not an integer") } - return val, nil + + counter -= value + + if err := c.Set(key, fmt.Sprintf("%v", counter), time.Time{}); err != nil { + return counter, err + } + return counter, nil } func (c *cache) ListKeys() ([]string, error) {