暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

flink model server 的实现

IT技术小输出 2021-04-11
358
flink已经是实时数据处理的代表,在bigdata + AI成为大趋势的情况下,flink早早就开源了FLIP-23来提供model server。但是flink model server的功能十分简陋,仅仅提供了一个简陋的框架和方向,目前仅仅支持TensorFlow的模型和pmml格式的模型。
其中pmml格式,就是将Python代码模型通过约定好的一个协议,转换成pmml格式的数据,再转成Java代码能调用的model,进行推理预测。但是这种跨语言的协议的使用不可避免的要牺牲一些性能,耗费一些资源。并且通过pmml转换的Java代码可调用的model,在预测结果上是会存在一定的偏差的,尽管偏差很小。

flink model server的实现:

目前的实现方式:

flink从kafka.data中拉取预测数据,从kafka.models中拉取训练好的模型。分别经过一些处理之后,如果只有一个模型进行推理预测,那么模型数据会以广播的形式发往下游算子与预测数据进行connect,如果是多个模型进行推理预测,就会以模型类型进行keyby,然后再进行connect进行预测推理。会以状态保存当前预测的模型和下一个需要进行预测的模型。


核心代码:
从kafka.models中拉取的model数据是以bye[]存在的,需要先进行反序列化,model server实现了一个核心方法,将byte[]转换为model数据:


/**

* Convert byte array to ModelToServe.

*

* @param binary byte representation of ModelDescriptor.proto.

* @return model to serve.

*/

public static Optional<ModelToServe> convertModel(byte[] binary){

try {

// Unmarshall record

Modeldescriptor.ModelDescriptor model = Modeldescriptor.ModelDescriptor.parseFrom(binary);

// Return it

if (model.getMessageContentCase().equals(Modeldescriptor.ModelDescriptor.MessageContentCase.DATA)){

return Optional.of(new ModelToServe(

model.getName(), model.getDescription(), model.getModeltype(),

model.getData().toByteArray(), null, model.getDataType()));

}

else {

return Optional.of(new ModelToServe(

model.getName(), model.getDescription(), model.getModeltype(),

null, model.getLocation(), model.getDataType()));

}

} catch (Throwable t) {

// Oops

System.out.println("Exception parsing input record" + new String(binary));

t.printStackTrace();

return Optional.empty();

}

}


其中,Modeldescriptor.ModelDescriptor是谷歌的protobuf序列化自动生成的。ModelToServe是描述model信息和保存数据的类。

从序列化数据转换为ModelToServer之后,会根据modelType,创建Factory,用来创建对应类型的model对象。
  private static final Map<Integer, ModelFactory> factories = new HashMap<Integer, ModelFactory>(){
{
put(Modeldescriptor.ModelDescriptor.ModelType.TENSORFLOW.getNumber(), WineTensorflowModelFactory.getInstance());
put(Modeldescriptor.ModelDescriptor.ModelType.PMML.getNumber(), WinePMMLModelFactory.getInstance());
}
};


/**
* Get factory based on type.
*
* @param type model type.
* @return model factory.
*/
@Override
public ModelFactory getFactory(int type) {
return factories.get(type);
}


        可以看到,目前仅仅实现了TensorFlow和pmml类型的model,根据modeltype来返回对应的ModelFactory实例。


/**
* Creates a new PMML model.
*
* @param descriptor model to serve representation of PMML model.
* @return model
*/
@Override
public Optional<Model> create(ModelToServe descriptor) {
try {
return Optional.of(new WinePMMLModel(descriptor.getModelData()));
}
catch (Throwable t){
System.out.println("Exception creating SpecificPMMLModel from " + descriptor);
t.printStackTrace();
return Optional.empty();
}
}
  /**
* Creates a new tensorflow (optimized) model.
*
* @param descriptor model to serve representation of tensorflow model.
* @return model
*/
@Override
public Optional<Model> create(ModelToServe descriptor) {
try {
return Optional.of(new WineTensorflowModel(descriptor.getModelData()));
}
catch (Throwable t){
System.out.println("Exception creating SpecificTensorflowModel from " + descriptor);
t.printStackTrace();
return Optional.empty();
}
}

这两种类型的Factory都实现了create方法,用来创建对应的model实例。

  /**
* Score data.
*
* @param input object to score.
*/
@Override
public Object score(Object input) {
// Convert input
Winerecord.WineRecord inputs = (Winerecord.WineRecord) input;
// Clear arguments
arguments.clear();
// Populate arguments with incoming data
for (InputField field : inputFields){
arguments.put(field.getName(), field.prepare(getValueByName(inputs, field.getName().getValue())));
}


// Calculate Output using PMML evaluator
Map<FieldName, ?> result = evaluator.evaluate(arguments);


// Prepare output
double rv = 0;
Object tresult = result.get(tname);
if (tresult instanceof Computable){
String value = ((Computable) tresult).getResult().toString();
rv = Double.parseDouble(value);
}
else {
rv = (Double) tresult;
}
return rv;
}
/**
* Score data.
*
* @param input object to score.
*/
@Override
public Object score(Object input) {
// Convert input data
Winerecord.WineRecord record = (Winerecord.WineRecord) input;
// Build input tensor
float[][] data = {{
(float) record.getFixedAcidity(),
(float) record.getVolatileAcidity(),
(float) record.getCitricAcid(),
(float) record.getResidualSugar(),
(float) record.getChlorides(),
(float) record.getFreeSulfurDioxide(),
(float) record.getTotalSulfurDioxide(),
(float) record.getDensity(),
(float) record.getPH(),
(float) record.getSulphates(),
(float) record.getAlcohol()
}};
Tensor modelInput = Tensor.create(data);
// Serve using tensorflow APIs
Tensor result = session.runner().feed("dense_1_input", modelInput).fetch("dense_3/Sigmoid").run().get(0);
// Convert result
long[] rshape = result.shape();
float[][] rMatrix = new float[(int) rshape[0]][(int) rshape[1]];
result.copyTo(rMatrix);
Intermediate value = new Intermediate(0, rMatrix[0][0]);
for (int i = 1; i < rshape[1]; i++){
if (rMatrix[0][i] > value.getValue()) {
value.setIndex(i);
value.setValue(rMatrix[0][i]);
}
}
return (double) value.getIndex();
}

每个model实例都实现了score方法,用来进行推理预测。可以根据实际需求进行重写。在connect之后,对connect流进行处理时,调用model的score方法,并将预测数据传入,就可以得到预测结果,并且会更新统计数据。


对于pmml格式的模型,需要多一步对pmml类型数据进行解析,使得Java可以调用。

/**
* Creates a new PMML model.
*
* @param input binary representation of PMML model.
*/
public PMMLModel(byte[] input) throws Throwable{
// Save bytes
bytes = input;
// unmarshal PMML
pmml = PMMLUtil.unmarshal(new ByteArrayInputStream(input));
// Optimize model
synchronized (this) {
for (Visitor optimizer : optimizers) {
try {
optimizer.applyTo(pmml);
} catch (Throwable t){
// Swallow it
}
}
}


目前flink model server仅仅实现了这两种类型,看起来是不会再进行进一步更新。根据实际需求,是可以在现有基础上进行二次开发,可以支持如pytorch等model。根据这套model server是可以实现成千上万个model一起推理预测的。

上述只是总体大概的一个处理流程,源码贴上:


https://github.com/apache/flink/tree/e9d32fd8eff6a45e38e1eee079a49589a21add15

文章转载自IT技术小输出,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论