
Commit 4583ccf

SparkNLP - 995 Introducing MistralAI LLMs (#14318)
* added mistral
* Mistral python API
1 parent 0ea5898 commit 4583ccf

File tree

7 files changed: +1358 -0 lines changed


python/sparknlp/annotator/seq2seq/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@
 from sparknlp.annotator.seq2seq.bart_transformer import *
 from sparknlp.annotator.seq2seq.llama2_transformer import *
 from sparknlp.annotator.seq2seq.m2m100_transformer import *
+from sparknlp.annotator.seq2seq.mistral_transformer import *
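
With this re-export, the new annotator becomes available from the top-level annotator package alongside the other seq2seq transformers. A minimal import check (a sketch, assuming a build that includes this commit; `name` is the class attribute defined in the new file below):

    from sparknlp.annotator import MistralTransformer

    print(MistralTransformer.name)  # "MistralTransformer"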
python/sparknlp/annotator/seq2seq/mistral_transformer.py

Lines changed: 349 additions & 0 deletions
@@ -0,0 +1,349 @@
# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes for the MistralTransformer."""

from sparknlp.common import *


class MistralTransformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
    """Mistral 7B

    Mistral 7B is a 7.3-billion-parameter model that stands out for its
    efficient and effective performance in natural language processing.
    Surpassing Llama 2 13B across all benchmarks and excelling over Llama 1
    34B in various aspects, Mistral 7B strikes a balance between English
    language tasks and code comprehension, rivaling the capabilities of
    CodeLlama 7B in the latter.

    Mistral 7B introduces Grouped-query attention (GQA) for quicker inference,
    enhancing processing speed without compromising accuracy. This streamlined
    approach ensures a smoother user experience, making Mistral 7B a practical
    choice for real-world applications.

    Additionally, Mistral 7B adopts Sliding Window Attention (SWA) to
    efficiently handle longer sequences at a reduced computational cost. This
    feature enhances the model's ability to process extensive textual input,
    expanding its utility in handling more complex tasks.

    In summary, Mistral 7B represents a notable advancement in language
    models, offering a reliable and versatile solution for various natural
    language processing challenges.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> mistral = MistralTransformer.pretrained() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("generation")

    The default model is ``"mistral-7b"``, if no name is provided. For
    available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?q=mistral>`__.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT``           ``DOCUMENT``
    ====================== ======================

    Parameters
    ----------
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    minOutputLength
        Minimum length of the sequence to be generated, by default 0
    maxOutputLength
        Maximum length of output text, by default 20
    doSample
        Whether or not to use sampling; use greedy decoding otherwise, by
        default False
    temperature
        The value used to modulate the next token probabilities, by default 1.0
    topK
        The number of highest probability vocabulary tokens to keep for
        top-k-filtering, by default 50
    topP
        Top cumulative probability for vocabulary tokens, by default 1.0

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.
    repetitionPenalty
        The parameter for repetition penalty, 1.0 means no penalty, by default
        1.0
    noRepeatNgramSize
        If set to int > 0, all ngrams of that size can only occur once, by
        default 0
    ignoreTokenIds
        A list of token ids which are ignored in the decoder's output, by
        default []

    Notes
    -----
    This is a very computationally expensive module, especially on longer
    sequences. The use of an accelerator such as GPU is recommended.

    References
    ----------
    - `Mistral 7B
      <https://mistral.ai/news/announcing-mistral-7b/>`__
    - https://github.com/mistralai/mistral-src

    **Paper Abstract:**

    *We introduce Mistral 7B v0.1, a 7-billion-parameter language model
    engineered for superior performance and efficiency. Mistral 7B outperforms
    Llama 2 13B across all evaluated benchmarks, and Llama 1 34B in reasoning,
    mathematics, and code generation. Our model leverages grouped-query
    attention (GQA) for faster inference, coupled with sliding window
    attention (SWA) to effectively handle sequences of arbitrary length with a
    reduced inference cost. We also provide a model fine-tuned to follow
    instructions, Mistral 7B -- Instruct, that surpasses the Llama 2 13B --
    Chat model both on human and automated benchmarks. Our models are released
    under the Apache 2.0 license.*

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("documents")
    >>> mistral = MistralTransformer.pretrained("mistral-7b") \\
    ...     .setInputCols(["documents"]) \\
    ...     .setMaxOutputLength(50) \\
    ...     .setOutputCol("generation")
    >>> pipeline = Pipeline().setStages([documentAssembler, mistral])
    >>> data = spark.createDataFrame([["My name is Leonardo."]]).toDF("text")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("generation.result").show(truncate=False)
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |result                                                                                                                                                                                              |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |[Leonardo Da Vinci invented the microscope?\n Question: Leonardo Da Vinci invented the microscope?\n Answer: No, Leonardo Da Vinci did not invent the microscope. The first microscope was invented |
    | in the late 16th century, long after Leonardo']                                                                                                                                                    |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    """

    name = "MistralTransformer"

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]

    outputAnnotatorType = AnnotatorType.DOCUMENT

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
                            typeConverter=TypeConverters.toInt)

    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
                            typeConverter=TypeConverters.toInt)

    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
                     typeConverter=TypeConverters.toBoolean)

    temperature = Param(Params._dummy(), "temperature", "The value used to modulate the next token probabilities",
                        typeConverter=TypeConverters.toFloat)

    topK = Param(Params._dummy(), "topK",
                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
                 typeConverter=TypeConverters.toInt)

    topP = Param(Params._dummy(), "topP",
                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
                 typeConverter=TypeConverters.toFloat)

    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
                              typeConverter=TypeConverters.toFloat)

    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
                              "If set to int > 0, all ngrams of that size can only occur once",
                              typeConverter=TypeConverters.toInt)

    ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
                           "A list of token ids which are ignored in the decoder's output",
                           typeConverter=TypeConverters.toListInt)

    def setIgnoreTokenIds(self, value):
        """A list of token ids which are ignored in the decoder's output.

        Parameters
        ----------
        value : List[int]
            The token ids to be filtered out
        """
        return self._set(ignoreTokenIds=value)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    def setMinOutputLength(self, value):
        """Sets minimum length of the sequence to be generated.

        Parameters
        ----------
        value : int
            Minimum length of the sequence to be generated
        """
        return self._set(minOutputLength=value)

    def setMaxOutputLength(self, value):
        """Sets maximum length of output text.

        Parameters
        ----------
        value : int
            Maximum length of output text
        """
        return self._set(maxOutputLength=value)

    def setDoSample(self, value):
        """Sets whether or not to use sampling, use greedy decoding otherwise.

        Parameters
        ----------
        value : bool
            Whether or not to use sampling; use greedy decoding otherwise
        """
        return self._set(doSample=value)

    def setTemperature(self, value):
        """Sets the value used to modulate the next token probabilities.

        Parameters
        ----------
        value : float
            The value used to modulate the next token probabilities
        """
        return self._set(temperature=value)

    def setTopK(self, value):
        """Sets the number of highest probability vocabulary tokens to keep for
        top-k-filtering.

        Parameters
        ----------
        value : int
            Number of highest probability vocabulary tokens to keep
        """
        return self._set(topK=value)

    def setTopP(self, value):
        """Sets the top cumulative probability for vocabulary tokens.

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.

        Parameters
        ----------
        value : float
            Cumulative probability for vocabulary tokens
        """
        return self._set(topP=value)

    def setRepetitionPenalty(self, value):
        """Sets the parameter for repetition penalty. 1.0 means no penalty.

        Parameters
        ----------
        value : float
            The repetition penalty

        References
        ----------
        See `Ctrl: A Conditional Transformer Language Model For Controllable
        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        """
        return self._set(repetitionPenalty=value)

    def setNoRepeatNgramSize(self, value):
        """Sets size of n-grams that can only occur once.

        If set to int > 0, all ngrams of that size can only occur once.

        Parameters
        ----------
        value : int
            Size of the n-grams that can only occur once
        """
        return self._set(noRepeatNgramSize=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.MistralTransformer", java_model=None):
        super(MistralTransformer, self).__init__(
            classname=classname,
            java_model=java_model
        )
        self._setDefault(
            minOutputLength=0,
            maxOutputLength=20,
            doSample=False,
            temperature=1,
            topK=50,
            topP=1,
            repetitionPenalty=1.0,
            noRepeatNgramSize=0,
            ignoreTokenIds=[],
            batchSize=1
        )

    @staticmethod
    def loadSavedModel(folder, spark_session, use_openvino=False):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession
        use_openvino : bool, optional
            Whether to use the OpenVINO engine, by default False

        Returns
        -------
        MistralTransformer
            The restored model
        """
        from sparknlp.internal import _MistralLoader
        jModel = _MistralLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return MistralTransformer(java_model=jModel)

    @staticmethod
    def pretrained(name="mistral-7b", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default "mistral-7b"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        MistralTransformer
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(MistralTransformer, name, lang, remote_loc)
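
Beyond the `pretrained()` path shown in the docstring, the `loadSavedModel` entry point above covers locally exported models. A minimal sketch of that route (the export folder path is a placeholder, not from this commit; `use_openvino` keeps its default of False):

    import sparknlp
    from sparknlp.base import DocumentAssembler
    from sparknlp.annotator import MistralTransformer
    from pyspark.ml import Pipeline

    spark = sparknlp.start()

    documentAssembler = DocumentAssembler() \
        .setInputCol("text") \
        .setOutputCol("documents")

    # Load a Mistral model previously exported to a local folder.
    mistral = MistralTransformer.loadSavedModel("/tmp/mistral_7b_export", spark) \
        .setInputCols(["documents"]) \
        .setMaxOutputLength(50) \
        .setOutputCol("generation")

    pipeline = Pipeline().setStages([documentAssembler, mistral])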

python/sparknlp/internal/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -276,6 +276,11 @@ def __init__(self, path, jspark):
         )


+class _MistralLoader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark, use_openvino=False):
+        super(_MistralLoader, self).__init__(
+            "com.johnsnowlabs.nlp.annotators.seq2seq.MistralTransformer.loadSavedModel", path, jspark, use_openvino)
+
 class _MarianLoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark):
         super(_MarianLoader, self).__init__(