diff --git a/encryption/spiral/spiral.go b/encryption/spiral/spiral.go index ae2fe3f..d42cae7 100644 --- a/encryption/spiral/spiral.go +++ b/encryption/spiral/spiral.go @@ -2,13 +2,14 @@ package spiral import ( - "crypto/aes" + stdaes "crypto/aes" "crypto/cipher" "errors" "fmt" "strings" "archgrid.xyz/ag/toolsbox/encryption" + "archgrid.xyz/ag/toolsbox/encryption/aes" "archgrid.xyz/ag/toolsbox/hash/sha512" verifyCode "archgrid.xyz/ag/toolsbox/random/verify_code" "archgrid.xyz/ag/toolsbox/serialize/base64" @@ -22,34 +23,59 @@ const ( ) // 根据给定的密钥字符串生成加解密使用的密钥。 -// 与Rust版本兼容:使用SHA512 hex字符串的字节表示。 func generateKey(key string) []byte { - hexStr := sha512.Sha512Hex([]byte(key)) - // 取hex字符串的第4-36字节(对应Rust版本) - return []byte(hexStr[4:36]) + keyBytes := sha512.Sha512([]byte(key)) + return keyBytes[4:36] +} + +// 使用原始密钥进行AES-CBC-256加密(不经过二次SHA256处理) +func encryptWithRawKey(data []byte, key []byte, ivGenerator aes.IVGenerator) ([]byte, error) { + block, err := stdaes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("创建加密单元失败,%w", err) + } + + var key32 [32]byte + copy(key32[:], key) + iv := ivGenerator(key32) + + plainText := encryption.Padding(data, block.BlockSize(), encryption.PKCS7Padding) + cipherText := make([]byte, len(plainText)) + mode := cipher.NewCBCEncrypter(block, iv[:]) + mode.CryptBlocks(cipherText, plainText) + + return cipherText, nil +} + +// 使用原始密钥进行AES-CBC-256解密(不经过二次SHA256处理) +func decryptWithRawKey(data []byte, key []byte, ivGenerator aes.IVGenerator) ([]byte, error) { + block, err := stdaes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("创建加密单元失败,%w", err) + } + + var key32 [32]byte + copy(key32[:], key) + iv := ivGenerator(key32) + + plainText := make([]byte, len(data)) + mode := cipher.NewCBCDecrypter(block, iv[:]) + mode.CryptBlocks(plainText, data) + + return encryption.Unpadding(plainText, encryption.PKCS7Padding), nil } // 对给定的数据进行加密。 func Encrypt(data string, strength ...Strength) (string, error) { + // 为了与Rust版本兼容,固定使用PrefixIVGenerator + ivGen := aes.PrefixIVGenerator key := verifyCode.RandStr(20) keyBytes := generateKey(key) - - // 直接使用crypto/aes,避免二次SHA256哈希 - block, err := aes.NewCipher(keyBytes) + // 直接使用keyBytes,不经过aes包的二次SHA256处理 + cipherData, err := encryptWithRawKey([]byte(data), keyBytes, ivGen) if err != nil { - return "", fmt.Errorf("创建加密单元失败,%w", err) + return "", fmt.Errorf("加密计算失败,%w", err) } - - // 使用key的前16字节作为IV(与Rust版本PrefixIVGenerator对应) - iv := keyBytes[:16] - - // PKCS7 padding - plainText := encryption.Padding([]byte(data), block.BlockSize(), encryption.PKCS7Padding) - - cipherData := make([]byte, len(plainText)) - mode := cipher.NewCBCEncrypter(block, iv) - mode.CryptBlocks(cipherData, plainText) - var result strings.Builder result.WriteString("[") result.WriteString(key) @@ -59,32 +85,21 @@ func Encrypt(data string, strength ...Strength) (string, error) { // 对给定的数据进行解密。 func Decrypt(data string, strength ...Strength) (string, error) { + // 为了与Rust版本兼容,固定使用PrefixIVGenerator + ivGen := aes.PrefixIVGenerator if message, found := strings.CutPrefix(data, "["); found { if len(message) > 20 { keySeed := message[:20] - keyBytes := generateKey(keySeed) - + key := generateKey(keySeed) cipherData, err := base64.FromBase64(message[20:]) if err != nil { return "", fmt.Errorf("密文损坏无法解析,%w", err) } - - // 直接使用crypto/aes,避免二次SHA256哈希 - block, err := aes.NewCipher(keyBytes) + // 直接使用key,不经过aes包的二次SHA256处理 + plainText, err := decryptWithRawKey(cipherData, key, ivGen) if err != nil { - return "", fmt.Errorf("创建加密单元失败,%w", err) + return "", fmt.Errorf("密文解密计算失败,%w", err) } - - // 使用key的前16字节作为IV(与Rust版本对应) - iv := keyBytes[:16] - - plainText := make([]byte, len(cipherData)) - mode := cipher.NewCBCDecrypter(block, iv) - mode.CryptBlocks(plainText, cipherData) - - // PKCS7 unpadding - plainText = encryption.Unpadding(plainText, encryption.PKCS7Padding) - return string(plainText), nil } return "", errors.New("密文缺损,无法完成解密。") diff --git a/encryption/spiral/spiral_test.go b/encryption/spiral/spiral_test.go new file mode 100644 index 0000000..4540530 --- /dev/null +++ b/encryption/spiral/spiral_test.go @@ -0,0 +1,15 @@ +package spiral + +import "testing" + +func TestDecode(t *testing.T) { + var origin = "[q3XvNHL7oTfVpHmZ2bOAnyVY/Q1Bm2dqsI8hfVA74R9CQb4vyksTD+Y9l4TT62o=" + decoded, err := Decrypt(origin) + if err != nil { + t.Fatalf("Decode failed: %v", err) + } + expected := "TmFRS0w6BIrAPA1Raj" + if decoded != expected { + t.Fatalf("Decoded value mismatch. Got: %s, Expected: %s", decoded, expected) + } +}