Skip to content

Commit cdb031a

Browse files
SparkNLP 1043 integrate new causal LM annotators to use OpenVINO (#14319)
* Phi2 scala api * Phi2 python api * Phi2 python and scala tests * Phi2 python and scala tests * added M2M100 openvino implementation * added phi2 openvino implementation * added openvino flag to python --------- Co-authored-by: Maziyar Panahi <maziyar.panahi@iscpif.fr>
1 parent 4583ccf commit cdb031a

File tree

14 files changed

+1692
-45
lines changed

14 files changed

+1692
-45
lines changed

python/sparknlp/annotator/seq2seq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
from sparknlp.annotator.seq2seq.bart_transformer import *
2020
from sparknlp.annotator.seq2seq.llama2_transformer import *
2121
from sparknlp.annotator.seq2seq.m2m100_transformer import *
22+
from sparknlp.annotator.seq2seq.phi2_transformer import *
2223
from sparknlp.annotator.seq2seq.mistral_transformer import *

python/sparknlp/annotator/seq2seq/m2m100_transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.M2M100Tran
350350
tgtLang="fr")
351351

352352
@staticmethod
353-
def loadSavedModel(folder, spark_session):
353+
def loadSavedModel(folder, spark_session, use_openvino=False):
354354
"""Loads a locally saved model.
355355
356356
Parameters
@@ -366,7 +366,7 @@ def loadSavedModel(folder, spark_session):
366366
The restored model
367367
"""
368368
from sparknlp.internal import _M2M100Loader
369-
jModel = _M2M100Loader(folder, spark_session._jsparkSession)._java_obj
369+
jModel = _M2M100Loader(folder, spark_session._jsparkSession, use_openvino)._java_obj
370370
return M2M100Transformer(java_model=jModel)
371371

372372
@staticmethod
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
# Copyright 2017-2022 John Snow Labs
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Contains classes for the Phi2Transformer."""
15+
16+
from sparknlp.common import *
17+
18+
19+
class Phi2Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
20+
"""Phi-2: Textbooks Are All You Need.
21+
22+
Phi-2 is a Transformer with 2.7 billion parameters. It was trained using the same data sources as Phi-1.5,
23+
augmented with a new data source that consists of various NLP synthetic texts and filtered websites
24+
(for safety and educational value). When assessed against benchmarks testing common sense, language understanding,
25+
and logical reasoning, Phi-2 showcased a nearly state-of-the-art performance among models with less than 13 billion
26+
parameters.
27+
28+
Phi-2 hasn't been fine-tuned through reinforcement learning from human feedback. The intention behind crafting
29+
this open-source model is to provide the research community with a non-restricted small model to explore vital
30+
safety challenges, such as reducing toxicity, understanding societal biases, enhancing controllability, and more.
31+
32+
Pretrained models can be loaded with :meth:`.pretrained` of the companion
33+
object:
34+
35+
>>> phi2 = Phi2Transformer.pretrained() \\
36+
... .setInputCols(["document"]) \\
37+
... .setOutputCol("generation")
38+
39+
40+
The default model is ``"phi2-7b"``, if no name is provided. For available
41+
pretrained models please see the `Models Hub
42+
<https://sparknlp.org/models?q=phi2>`__.
43+
44+
====================== ======================
45+
Input Annotation types Output Annotation type
46+
====================== ======================
47+
``DOCUMENT`` ``DOCUMENT``
48+
====================== ======================
49+
50+
Parameters
51+
----------
52+
configProtoBytes
53+
ConfigProto from tensorflow, serialized into byte array.
54+
minOutputLength
55+
Minimum length of the sequence to be generated, by default 0
56+
maxOutputLength
57+
Maximum length of output text, by default 20
58+
doSample
59+
Whether or not to use sampling; use greedy decoding otherwise, by default False
60+
temperature
61+
The value used to module the next token probabilities, by default 1.0
62+
topK
63+
The number of highest probability vocabulary tokens to keep for
64+
top-k-filtering, by default 50
65+
topP
66+
Top cumulative probability for vocabulary tokens, by default 1.0
67+
68+
If set to float < 1, only the most probable tokens with probabilities
69+
that add up to ``topP`` or higher are kept for generation.
70+
repetitionPenalty
71+
The parameter for repetition penalty, 1.0 means no penalty. , by default
72+
1.0
73+
noRepeatNgramSize
74+
If set to int > 0, all ngrams of that size can only occur once, by
75+
default 0
76+
ignoreTokenIds
77+
A list of token ids which are ignored in the decoder's output, by
78+
default []
79+
80+
Notes
81+
-----
82+
This is a very computationally expensive module especially on larger
83+
sequence. The use of an accelerator such as GPU is recommended.
84+
85+
References
86+
----------
87+
- `Phi-2: Textbooks Are All You Need.
88+
<https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/>`__
89+
- https://huggingface.co/microsoft/phi-2
90+
91+
**Paper Abstract:**
92+
93+
*Phi-2 is a 2.7 billion-parameter language model that demonstrates outstanding
94+
reasoning and language understanding capabilities, showcasing state-of-the-art
95+
performance among base language models with less than 13 billion parameters.
96+
On complex benchmarks, Phi-2 matches or outperforms models up to 25x larger,
97+
thanks to new innovations in model scaling and training data curation. With
98+
its compact size, Phi-2 is an ideal playground for researchers, including for
99+
exploration around mechanistic interpretability, safety improvements, or
100+
fine-tuning experimentation on a variety of tasks.*
101+
102+
Examples
103+
--------
104+
>>> import sparknlp
105+
>>> from sparknlp.base import *
106+
>>> from sparknlp.annotator import *
107+
>>> from pyspark.ml import Pipeline
108+
>>> documentAssembler = DocumentAssembler() \\
109+
... .setInputCol("text") \\
110+
... .setOutputCol("documents")
111+
>>> phi2 = Phi2Transformer.pretrained("phi2-7b") \\
112+
... .setInputCols(["documents"]) \\
113+
... .setMaxOutputLength(50) \\
114+
... .setOutputCol("generation")
115+
>>> pipeline = Pipeline().setStages([documentAssembler, phi2])
116+
>>> data = spark.createDataFrame([["My name is Leonardo."]]).toDF("text")
117+
>>> result = pipeline.fit(data).transform(data)
118+
>>> result.select("summaries.generation").show(truncate=False)
119+
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
120+
|result |
121+
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
122+
|[My name is Leonardo . I am a student of the University of California, Berkeley. I am interested in the field of Artificial Intelligence and its applications in the real world. I have a strong |
123+
| passion for learning and am always looking for ways to improve my knowledge and skills] |
124+
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
125+
"""
126+
127+
name = "Phi2Transformer"
128+
129+
inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
130+
131+
outputAnnotatorType = AnnotatorType.DOCUMENT
132+
133+
configProtoBytes = Param(Params._dummy(), "configProtoBytes",
134+
"ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
135+
TypeConverters.toListInt)
136+
137+
minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
138+
typeConverter=TypeConverters.toInt)
139+
140+
maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
141+
typeConverter=TypeConverters.toInt)
142+
143+
doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
144+
typeConverter=TypeConverters.toBoolean)
145+
146+
temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities",
147+
typeConverter=TypeConverters.toFloat)
148+
149+
topK = Param(Params._dummy(), "topK",
150+
"The number of highest probability vocabulary tokens to keep for top-k-filtering",
151+
typeConverter=TypeConverters.toInt)
152+
153+
topP = Param(Params._dummy(), "topP",
154+
"If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
155+
typeConverter=TypeConverters.toFloat)
156+
157+
repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
158+
"The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
159+
typeConverter=TypeConverters.toFloat)
160+
161+
noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
162+
"If set to int > 0, all ngrams of that size can only occur once",
163+
typeConverter=TypeConverters.toInt)
164+
165+
ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
166+
"A list of token ids which are ignored in the decoder's output",
167+
typeConverter=TypeConverters.toListInt)
168+
169+
def setIgnoreTokenIds(self, value):
170+
"""A list of token ids which are ignored in the decoder's output.
171+
172+
Parameters
173+
----------
174+
value : List[int]
175+
The words to be filtered out
176+
"""
177+
return self._set(ignoreTokenIds=value)
178+
179+
def setConfigProtoBytes(self, b):
180+
"""Sets configProto from tensorflow, serialized into byte array.
181+
182+
Parameters
183+
----------
184+
b : List[int]
185+
ConfigProto from tensorflow, serialized into byte array
186+
"""
187+
return self._set(configProtoBytes=b)
188+
189+
def setMinOutputLength(self, value):
190+
"""Sets minimum length of the sequence to be generated.
191+
192+
Parameters
193+
----------
194+
value : int
195+
Minimum length of the sequence to be generated
196+
"""
197+
return self._set(minOutputLength=value)
198+
199+
def setMaxOutputLength(self, value):
200+
"""Sets maximum length of output text.
201+
202+
Parameters
203+
----------
204+
value : int
205+
Maximum length of output text
206+
"""
207+
return self._set(maxOutputLength=value)
208+
209+
def setDoSample(self, value):
210+
"""Sets whether or not to use sampling, use greedy decoding otherwise.
211+
212+
Parameters
213+
----------
214+
value : bool
215+
Whether or not to use sampling; use greedy decoding otherwise
216+
"""
217+
return self._set(doSample=value)
218+
219+
def setTemperature(self, value):
220+
"""Sets the value used to module the next token probabilities.
221+
222+
Parameters
223+
----------
224+
value : float
225+
The value used to module the next token probabilities
226+
"""
227+
return self._set(temperature=value)
228+
229+
def setTopK(self, value):
230+
"""Sets the number of highest probability vocabulary tokens to keep for
231+
top-k-filtering.
232+
233+
Parameters
234+
----------
235+
value : int
236+
Number of highest probability vocabulary tokens to keep
237+
"""
238+
return self._set(topK=value)
239+
240+
def setTopP(self, value):
241+
"""Sets the top cumulative probability for vocabulary tokens.
242+
243+
If set to float < 1, only the most probable tokens with probabilities
244+
that add up to ``topP`` or higher are kept for generation.
245+
246+
Parameters
247+
----------
248+
value : float
249+
Cumulative probability for vocabulary tokens
250+
"""
251+
return self._set(topP=value)
252+
253+
def setRepetitionPenalty(self, value):
254+
"""Sets the parameter for repetition penalty. 1.0 means no penalty.
255+
256+
Parameters
257+
----------
258+
value : float
259+
The repetition penalty
260+
261+
References
262+
----------
263+
See `Ctrl: A Conditional Transformer Language Model For Controllable
264+
Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
265+
"""
266+
return self._set(repetitionPenalty=value)
267+
268+
def setNoRepeatNgramSize(self, value):
269+
"""Sets size of n-grams that can only occur once.
270+
271+
If set to int > 0, all ngrams of that size can only occur once.
272+
273+
Parameters
274+
----------
275+
value : int
276+
N-gram size can only occur once
277+
"""
278+
return self._set(noRepeatNgramSize=value)
279+
280+
@keyword_only
281+
def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.Phi2Transformer", java_model=None):
282+
super(Phi2Transformer, self).__init__(classname=classname, java_model=java_model)
283+
self._setDefault(minOutputLength=0, maxOutputLength=20, doSample=False, temperature=0.6, topK=50, topP=0.9,
284+
repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1)
285+
286+
@staticmethod
287+
def loadSavedModel(folder, spark_session, use_openvino=False):
288+
"""Loads a locally saved model.
289+
290+
Parameters
291+
----------
292+
folder : str
293+
Folder of the saved model
294+
spark_session : pyspark.sql.SparkSession
295+
The current SparkSession
296+
297+
Returns
298+
-------
299+
Phi2Transformer
300+
The restored model
301+
"""
302+
from sparknlp.internal import _Phi2Loader
303+
jModel = _Phi2Loader(folder, spark_session._jsparkSession, use_openvino)._java_obj
304+
return Phi2Transformer(java_model=jModel)
305+
306+
@staticmethod
307+
def pretrained(name="phi2-7b", lang="en", remote_loc=None):
308+
"""Downloads and loads a pretrained model.
309+
310+
Parameters
311+
----------
312+
name : str, optional
313+
Name of the pretrained model, by default "phi2-7b"
314+
lang : str, optional
315+
Language of the pretrained model, by default "en"
316+
remote_loc : str, optional
317+
Optional remote address of the resource, by default None. Will use
318+
Spark NLPs repositories otherwise.
319+
320+
Returns
321+
-------
322+
Phi2Transformer
323+
The restored model
324+
"""
325+
from sparknlp.pretrained import ResourceDownloader
326+
return ResourceDownloader.downloadModel(Phi2Transformer, name, lang, remote_loc)

python/sparknlp/internal/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def __init__(self, path, jspark):
268268

269269

270270
class _M2M100Loader(ExtendedJavaWrapper):
271-
def __init__(self, path, jspark):
271+
def __init__(self, path, jspark, use_openvino=False):
272272
super(_M2M100Loader, self).__init__(
273273
"com.johnsnowlabs.nlp.annotators.seq2seq.M2M100Transformer.loadSavedModel",
274274
path,
@@ -279,7 +279,12 @@ def __init__(self, path, jspark):
279279
class _MistralLoader(ExtendedJavaWrapper):
280280
def __init__(self, path, jspark, use_openvino=False):
281281
super(_MistralLoader, self).__init__(
282-
"com.johnsnowlabs.nlp.annotators.seq2seq.MistralTransformer.loadSavedModel", path, jspark, use_openvino)
282+
"com.johnsnowlabs.nlp.annotators.seq2seq.MistralTransformer.loadSavedModel",
283+
path,
284+
jspark,
285+
use_openvino,
286+
)
287+
283288

284289
class _MarianLoader(ExtendedJavaWrapper):
285290
def __init__(self, path, jspark):
@@ -299,6 +304,16 @@ def __init__(self, path, jspark):
299304
)
300305

301306

307+
class _Phi2Loader(ExtendedJavaWrapper):
308+
def __init__(self, path, jspark, use_openvino=False):
309+
super(_Phi2Loader, self).__init__(
310+
"com.johnsnowlabs.nlp.annotators.seq2seq.Phi2Transformer.loadSavedModel",
311+
path,
312+
jspark,
313+
use_openvino,
314+
)
315+
316+
302317
class _RoBertaLoader(ExtendedJavaWrapper):
303318
def __init__(self, path, jspark, use_openvino=False):
304319
super(_RoBertaLoader, self).__init__(

0 commit comments

Comments
 (0)