Skip to content

Commit f6cc672

Browse files
authored
Merge pull request #4 from munisystem/master
Support multibyte characters
2 parents 5cd0009 + a307752 commit f6cc672

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

encoder.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ func loadFiles() (bpeData []byte, encoder []byte, err error) {
3333
}
3434

3535
// hardcoded
36-
func bytesToUnicode() map[rune]string {
37-
return map[rune]string{
36+
func bytesToUnicode() map[byte]string {
37+
return map[byte]string{
3838
0: "Ā",
3939
1: "ā",
4040
2: "Ă",
@@ -309,8 +309,8 @@ type Encoder struct {
309309
bpeRank map[lo.Tuple2[string, string]]int
310310
encoder map[string]int
311311
decoder map[int]string
312-
byteEncoder map[rune]string
313-
byteDecoder map[string]rune
312+
byteEncoder map[byte]string
313+
byteDecoder map[string]byte
314314

315315
cache sync.Map
316316
}
@@ -462,9 +462,9 @@ func (e *Encoder) Encode(text string) ([]int, error) {
462462
}
463463

464464
for _, match := range matches {
465-
runes := []rune(match)
465+
b := []byte(match)
466466

467-
token := strings.Join(lo.Map(runes, func(item rune, _ int) string {
467+
token := strings.Join(lo.Map(b, func(item byte, _ int) string {
468468
return e.byteEncoder[item]
469469
}), "")
470470

@@ -484,7 +484,7 @@ func (e *Encoder) Decode(tokens []int) string {
484484

485485
parts = lo.ChunkString(strings.Join(parts, ""), 1)
486486

487-
text := lo.Map(parts, func(item string, _ int) rune {
487+
text := lo.Map(parts, func(item string, _ int) byte {
488488
return e.byteDecoder[item]
489489
})
490490

encoder_test.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func TestBytesToUnicode(t *testing.T) {
1111
is := assert.New(t)
1212

1313
// Most useful test E.V.E.R ^^
14-
want := map[rune]string{
14+
want := map[byte]string{
1515
0: "Ā",
1616
1: "ā",
1717
2: "Ă",
@@ -454,11 +454,10 @@ func TestNewEncoder_encode(t *testing.T) {
454454
is.EqualValues(want, got)
455455
is.Nil(err)
456456

457-
// @TODO
458-
// want = []int{31373, 50169, 233, 995, 12520, 234, 235, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0}
459-
// got, err = encoder.Encode("hello 👋 world 🌍 This is a long string to test whether or not the emoji issue was fixed!")
460-
// is.EqualValues(want, got)
461-
// is.Nil(err)
457+
want = []int{31373, 50169, 233, 995, 12520, 234, 235, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0}
458+
got, err = encoder.Encode("hello 👋 world 🌍 This is a long string to test whether or not the emoji issue was fixed!")
459+
is.EqualValues(want, got)
460+
is.Nil(err)
462461
}
463462

464463
func TestNewEncoder_decode(t *testing.T) {
@@ -471,10 +470,9 @@ func TestNewEncoder_decode(t *testing.T) {
471470
got := encoder.Decode([]int{31373, 995, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0})
472471
is.EqualValues(want, got)
473472

474-
// @TODO
475-
// want = "hello 👋 world 🌍 This is a long string to test whether or not the emoji issue was fixed!"
476-
// got = encoder.Decode([]int{31373, 50169, 233, 995, 12520, 234, 235, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0})
477-
// is.EqualValues(want, got)
473+
want = "hello 👋 world 🌍 This is a long string to test whether or not the emoji issue was fixed!"
474+
got = encoder.Decode([]int{31373, 50169, 233, 995, 12520, 234, 235, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0})
475+
is.EqualValues(want, got)
478476
}
479477

480478
func TestNewEncoder_e2e(t *testing.T) {
@@ -489,7 +487,8 @@ func TestNewEncoder_e2e(t *testing.T) {
489487
lo.T2("\t", []int{197}),
490488
lo.T2("This is some text", []int{1212, 318, 617, 2420}),
491489
lo.T2("indivisible", []int{521, 452, 12843}),
492-
// lo.T2("hello 👋 world 🌍", []int{31373, 50169, 233, 995, 12520, 234, 235}), // @TODO
490+
lo.T2("hello 👋 world 🌍", []int{31373, 50169, 233, 995, 12520, 234, 235}),
491+
lo.T2("hello, 世界", []int{31373, 11, 220, 10310, 244, 45911, 234}),
493492
}
494493

495494
for _, c := range cases {

0 commit comments

Comments
 (0)