Skip to content

Commit 3054d4c

Browse files
authored
Fixed LLAMA generation bug (#14320)
* Fixed LLAMA generation bug
* Updated params
1 parent cdb031a commit 3054d4c

File tree

5 files changed

+37
-54
lines changed

5 files changed

+37
-54
lines changed

src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ private[johnsnowlabs] class LLAMA2(
164164
randomSeed,
165165
ignoreTokenIdsInt,
166166
session,
167-
applySoftmax = false,
167+
applySoftmax = true,
168168
ovInferRequest = ovInferRequest)
169169

170170
modelOutputs

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,10 +392,7 @@ trait Generate {
392392
seededRandom = new scala.util.Random(seed.get)
393393
}
394394
for (i <- 0 until k) {
395-
var rand = scala.util.Random.nextDouble()
396-
if (seed.isDefined) {
397-
rand = new scala.util.Random(seed.get).nextDouble()
398-
}
395+
val rand = seededRandom.nextDouble()
399396
var cumProb = 0.0
400397
var j = 0
401398
while (j < probabilities.length - i) {

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopKLogitWarper.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,25 @@ import scala.collection.mutable.ArrayBuffer
2020
class TopKLogitWarper(
2121
val k: Int,
2222
val filterValue: Float = Float.NegativeInfinity,
23-
val minTokensToKeep: Int = 1)
23+
val minTokensToKeep: Int = 100)
2424
extends LogitWarper {
2525
override def call(
2626
inputIds: Seq[Array[Int]],
2727
scores: Array[Array[Float]],
2828
currentLength: Int): Array[Array[Float]] = {
29-
var logitsUpd = scores
30-
val logitsShape = Array(scores.length, scores(0).length)
31-
if (k > 0) {
32-
val topKup = k.max(minTokensToKeep).min(logitsShape.last) // Safety check
29+
val logitsUpd = scores.map(_.clone()) // Deep copy of the scores
3330

34-
/** Remove all tokens with a probability less than the last token of the top-k */
31+
if (k > 0) {
32+
val logitsShape = Array(scores.length, scores.head.length)
33+
val effectiveTopK = k.max(minTokensToKeep).min(logitsShape.last) // Safety check
3534

36-
val topKLogits = new ArrayBuffer[Array[Float]]()
37-
for (logits <- scores) {
38-
val topKIndices = getTopKIndices(logits, topKup)
35+
for ((logits, i) <- scores.zipWithIndex) {
36+
val topKIndices = getTopKIndices(logits, effectiveTopK)
3937
val maskedValues = maskNotTopKValues(logits, topKIndices)
40-
topKLogits += maskedValues
38+
logitsUpd(i) = maskedValues
4139
}
42-
topKLogits.toArray
4340
}
41+
4442
logitsUpd
4543
}
4644

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopPLogitWarper.scala

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,46 +21,34 @@ class TopPLogitWarper(val p: Double, val minTokensToKeep: Int = 1) extends Logit
2121
inputIds: Seq[Array[Int]],
2222
scores: Array[Array[Float]],
2323
currentLength: Int): Array[Array[Float]] = {
24-
var scoresUpd = scores
25-
val scoresShape = Array(scores.length, scores(0).length)
26-
if (this.p < 1.0) {
27-
val (sortedscores, sortedIndices) = scores(0).zipWithIndex.sorted.reverse.unzip
24+
val logitsUpd = scores.map(_.clone()) // Deep copy of the scores
2825

29-
val cumulativeProbs = this.scanLeft(this.softmax(sortedscores))(0.0)(_ + _).drop(1)
26+
if (p < 1.0) {
27+
val scoresFiltered = scores.map(_.filterNot(_.isInfinite)) // Filter out infinite values
28+
val scoresShape = Array(scoresFiltered.length, scoresFiltered.head.length)
29+
val topPThreshold = math.ceil(p * scoresShape.last).toInt // Determine top-p threshold
3030

31-
/** Remove tokens with cumulative probability above the threshold (token with 0 are kept) */
32-
var sortedIndicesToRemove =
33-
for (prob <- cumulativeProbs)
34-
yield if (prob > this.p) true else false
35-
36-
if (minTokensToKeep > 1) {
37-
38-
/** Keep at least minTokensToKeep (set to minTokensToKeep-1 because we add the first one
39-
* below)
40-
*/
41-
sortedIndicesToRemove = List.fill(sortedIndicesToRemove.take(minTokensToKeep).length)(
42-
false) ++ sortedIndicesToRemove.drop(minTokensToKeep)
31+
for ((logits, i) <- scores.zipWithIndex) {
32+
val topPIndices = getTopPIndices(logits, topPThreshold)
33+
val maskedValues = maskNotTopPValues(logits, topPIndices)
34+
logitsUpd(i) = maskedValues
4335
}
36+
}
37+
38+
logitsUpd
39+
}
4440

45-
/** Shift the indices to the right to keep also the first token above the threshold */
46-
sortedIndicesToRemove = sortedIndicesToRemove.takeRight(1) ++ sortedIndicesToRemove
47-
.dropRight(1)
48-
sortedIndicesToRemove =
49-
List.fill(sortedIndicesToRemove.take(1).length)(false) ++ sortedIndicesToRemove
50-
.drop(1)
41+
private def getTopPIndices(logits: Array[Float], k: Int): Array[Int] = {
42+
logits.zipWithIndex.sortBy(-_._1).take(k).map(_._2)
43+
}
5144

52-
/** scatter sorted tensors to original indexing */
53-
val indicesToRemove =
54-
this.scatterValuesOnBatchIndices(sortedIndicesToRemove.toList, sortedIndices)
55-
scoresUpd =
56-
for ((nextTokenLogit, indexToRemove) <- scores.zip(
57-
IndexedSeq.fill(scores.length)(indicesToRemove)))
58-
yield setTensorByIndicesToValue(
59-
nextTokenLogit,
60-
indexToRemove.toIndexedSeq,
61-
Float.NegativeInfinity)
45+
private def maskNotTopPValues(logits: Array[Float], topPIndices: Array[Int]): Array[Float] = {
46+
val maskedValues = logits.clone()
47+
for (i <- logits.indices) {
48+
if (!topPIndices.contains(i)) {
49+
maskedValues(i) = Float.NegativeInfinity
50+
}
6251
}
63-
scoresUpd
52+
maskedValues
6453
}
65-
6654
}

src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,11 @@ class LLAMA2Transformer(override val uid: String)
227227
minOutputLength -> 0,
228228
maxOutputLength -> 20,
229229
doSample -> false,
230-
temperature -> 0.6,
231-
topK -> 50,
230+
temperature -> 0.9,
231+
topK -> 100,
232232
topP -> 0.9,
233233
repetitionPenalty -> 1.0,
234-
noRepeatNgramSize -> 3,
234+
noRepeatNgramSize -> 0,
235235
ignoreTokenIds -> Array(),
236236
batchSize -> 1,
237237
beamSize -> 1,

0 commit comments

Comments (0)