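Flink pulls the records to be scored from kafka.data and the trained models from kafka.models. After some preprocessing of each stream, the model stream is joined to the data stream with connect.

- If a single model serves all predictions, the model records are broadcast to the downstream operator and connected with the data stream.
- If several models serve predictions, the streams are keyed by model type (keyBy) before the connect, so each record meets the model for its own type.

The operator keeps two pieces of state: the model currently used for scoring, and the next model waiting to replace it.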
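The keyed case can be pictured with a minimal plain-Java sketch. All names here (KeyedModelStateSketch, Model, Slot) are hypothetical. A HashMap stands in for Flink keyed state, and the model type plays the role of the keyBy key. The sketch shows one plausible swap policy:

- A newly arrived model is parked as "next".
- It is promoted to "current" when the next data record for that key arrives.

In a real Flink job this would be keyed state (for example ValueState) inside the function that processes the connected streams.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

/** Plain-Java simulation of per-key "current model / next model" state (hypothetical names). */
public class KeyedModelStateSketch {

    /** Hypothetical stand-in for a model: maps one feature value to a score. */
    public interface Model extends Function<Double, Double> {}

    /** Per-key state: the model in use and the one waiting to replace it. */
    private static final class Slot {
        Model current; // model currently used for scoring
        Model next;    // newly arrived model, promoted on the next data record
    }

    // One slot per model type; stands in for Flink keyed state under keyBy(modelType)
    private final Map<String, Slot> slots = new HashMap<>();

    /** Model stream: park the new model as "next" for its key. */
    public void onModel(String modelType, Model model) {
        slots.computeIfAbsent(modelType, k -> new Slot()).next = model;
    }

    /** Data stream: promote a pending model, then score with the current one. */
    public Optional<Double> onData(String modelType, double feature) {
        Slot slot = slots.get(modelType);
        if (slot == null) {
            return Optional.empty(); // no model has arrived for this key yet
        }
        if (slot.next != null) {
            slot.current = slot.next; // swap in the new model
            slot.next = null;
        }
        if (slot.current == null) {
            return Optional.empty();
        }
        return Optional.of(slot.current.apply(feature));
    }

    public static void main(String[] args) {
        KeyedModelStateSketch server = new KeyedModelStateSketch();
        System.out.println(server.onData("pmml", 1.0));       // Optional.empty
        server.onModel("pmml", x -> x * 2);
        System.out.println(server.onData("pmml", 1.0));       // Optional[2.0]
        server.onModel("pmml", x -> x + 10);
        System.out.println(server.onData("pmml", 1.0));       // Optional[11.0]
        System.out.println(server.onData("tensorflow", 1.0)); // Optional.empty
    }
}
```

Promoting on the next data record, rather than replacing immediately, is just one way to keep the swap on the scoring path. The single-model broadcast case behaves the same way, except that every parallel instance holds the same slot.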
Each record read from kafka.models is a serialized ModelDescriptor protobuf message, which is first converted into a ModelToServe:
/**
 * Convert byte array to ModelToServe.
 *
 * @param binary byte representation of ModelDescriptor.proto.
 * @return model to serve, or Optional.empty() if the record cannot be parsed.
 */
public static Optional<ModelToServe> convertModel(byte[] binary){
    try {
        // Unmarshal record
        Modeldescriptor.ModelDescriptor model = Modeldescriptor.ModelDescriptor.parseFrom(binary);
        // The model is either embedded as bytes (DATA) or referenced by its location
        if (model.getMessageContentCase() == Modeldescriptor.ModelDescriptor.MessageContentCase.DATA){
            return Optional.of(new ModelToServe(
                model.getName(), model.getDescription(), model.getModeltype(),
                model.getData().toByteArray(), null, model.getDataType()));
        }
        else {
            return Optional.of(new ModelToServe(
                model.getName(), model.getDescription(), model.getModeltype(),
                null, model.getLocation(), model.getDataType()));
        }
    } catch (Throwable t) {
        // Malformed record: log it and skip it
        System.out.println("Exception parsing input record: " + new String(binary));
        t.printStackTrace();
        return Optional.empty();
    }
}
// Factories indexed by the numeric value of ModelDescriptor.ModelType
private static final Map<Integer, ModelFactory> factories = new HashMap<>();
static {
    factories.put(Modeldescriptor.ModelDescriptor.ModelType.TENSORFLOW.getNumber(), WineTensorflowModelFactory.getInstance());
    factories.put(Modeldescriptor.ModelDescriptor.ModelType.PMML.getNumber(), WinePMMLModelFactory.getInstance());
}

/**
 * Get factory based on type.
 *
 * @param type model type.
 * @return model factory, or null if the type is not supported.
 */
@Override
public ModelFactory getFactory(int type) {
    return factories.get(type);
}
As the code shows, only TensorFlow and PMML models are implemented so far. getFactory returns the ModelFactory instance that matches the model type.
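Because `factories.get(type)` returns null for any type that was never registered, callers have to null-check the result. A minimal sketch of an Optional-returning lookup follows. The cut-down ModelFactory interface and the type numbers are placeholders, not the real ModelDescriptor.ModelType values.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/** Hypothetical registry mirroring getFactory(int type), but returning Optional. */
public class FactoryRegistrySketch {

    /** Hypothetical, cut-down factory role (the real ModelFactory has create(...)). */
    public interface ModelFactory {
        String name();
    }

    // Placeholder type numbers; the real ones come from ModelDescriptor.ModelType.getNumber()
    public static final int TENSORFLOW = 1;
    public static final int PMML = 2;

    private static final Map<Integer, ModelFactory> FACTORIES = new HashMap<>();

    static {
        FACTORIES.put(TENSORFLOW, () -> "tensorflow");
        FACTORIES.put(PMML, () -> "pmml");
    }

    /** Like getFactory(int), but an unsupported type yields Optional.empty() instead of null. */
    public static Optional<ModelFactory> getFactory(int type) {
        return Optional.ofNullable(FACTORIES.get(type));
    }

    public static void main(String[] args) {
        System.out.println(getFactory(PMML).map(ModelFactory::name).orElse("unsupported")); // pmml
        System.out.println(getFactory(99).map(ModelFactory::name).orElse("unsupported"));   // unsupported
    }
}
```

Wrapping the lookup this way lets an unsupported model type be dropped with a log message instead of surfacing later as a NullPointerException.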
/**
 * Creates a new PMML model.
 *
 * @param descriptor model to serve representation of PMML model.
 * @return model, or Optional.empty() if the model cannot be built.
 */
@Override
public Optional<Model> create(ModelToServe descriptor) {
    try {
        return Optional.of(new WinePMMLModel(descriptor.getModelData()));
    }
    catch (Throwable t){
        System.out.println("Exception creating WinePMMLModel from " + descriptor);
        t.printStackTrace();
        return Optional.empty();
    }
}
/**
 * Creates a new tensorflow (optimized) model.
 *
 * @param descriptor model to serve representation of tensorflow model.
 * @return model, or Optional.empty() if the model cannot be built.
 */
@Override
public Optional<Model> create(ModelToServe descriptor) {
    try {
        return Optional.of(new WineTensorflowModel(descriptor.getModelData()));
    }
    catch (Throwable t){
        System.out.println("Exception creating WineTensorflowModel from " + descriptor);
        t.printStackTrace();
        return Optional.empty();
    }
}
Both factory types implement create, which builds the corresponding model instance from a ModelToServe.
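The create pattern above turns any failure while building a model into Optional.empty(). A bad model record is then skipped instead of crashing the job. The self-contained sketch below shows the same pattern on a hypothetical ThresholdModel, whose "model bytes" are simply a decimal number.

```java
import java.nio.charset.StandardCharsets;
import java.util.Optional;

/** Sketch of the create() pattern: construction errors become Optional.empty(). */
public class FactoryCreateSketch {

    /** Hypothetical model: "parses" its bytes as a decimal threshold. */
    public static final class ThresholdModel {
        public final double threshold;

        ThresholdModel(byte[] data) {
            // Throws NumberFormatException for non-numeric bytes, NullPointerException for null
            this.threshold = Double.parseDouble(new String(data, StandardCharsets.UTF_8));
        }
    }

    /** Build a model, or return Optional.empty() if the bytes cannot be parsed. */
    public static Optional<ThresholdModel> create(byte[] modelData) {
        try {
            return Optional.of(new ThresholdModel(modelData));
        } catch (RuntimeException e) {
            System.err.println("Exception creating ThresholdModel: " + e);
            return Optional.empty();
        }
    }

    public static void main(String[] args) {
        System.out.println(create("0.5".getBytes(StandardCharsets.UTF_8))
                .map(m -> m.threshold).orElse(-1.0));                                    // 0.5
        System.out.println(create("not a number".getBytes(StandardCharsets.UTF_8)).isPresent()); // false
    }
}
```

The sketch catches RuntimeException rather than Throwable, unlike the factories above, so that JVM errors such as OutOfMemoryError are not silently swallowed. Whether that trade-off suits a given deployment is a judgment call.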
/**
 * Score data.
 *
 * @param input object to score.
 * @return predicted value of the target field, as a double.
 */
@Override
public Object score(Object input) {
    // Convert input
    Winerecord.WineRecord inputs = (Winerecord.WineRecord) input;
    // Clear arguments
    arguments.clear();
    // Populate arguments with incoming data
    for (InputField field : inputFields){
        arguments.put(field.getName(), field.prepare(getValueByName(inputs, field.getName().getValue())));
    }
    // Calculate Output using PMML evaluator
    Map<FieldName, ?> result = evaluator.evaluate(arguments);
    // Prepare output: the target is either a Computable wrapper or a plain number
    Object tresult = result.get(tname);
    double rv;
    if (tresult instanceof Computable){
        String value = ((Computable) tresult).getResult().toString();
        rv = Double.parseDouble(value);
    }
    else {
        rv = ((Number) tresult).doubleValue();
    }
    return rv;
}
/**
 * Score data.
 *
 * @param input object to score.
 * @return index of the highest-scoring output class, as a double.
 */
@Override
public Object score(Object input) {
    // Convert input data
    Winerecord.WineRecord record = (Winerecord.WineRecord) input;
    // Build input tensor
    float[][] data = {{
        (float) record.getFixedAcidity(),
        (float) record.getVolatileAcidity(),
        (float) record.getCitricAcid(),
        (float) record.getResidualSugar(),
        (float) record.getChlorides(),
        (float) record.getFreeSulfurDioxide(),
        (float) record.getTotalSulfurDioxide(),
        (float) record.getDensity(),
        (float) record.getPH(),
        (float) record.getSulphates(),
        (float) record.getAlcohol()
    }};
    // Tensors hold native memory, so close them after each call (try-with-resources)
    try (Tensor modelInput = Tensor.create(data);
         // Serve using tensorflow APIs
         Tensor result = session.runner().feed("dense_1_input", modelInput).fetch("dense_3/Sigmoid").run().get(0)) {
        // Convert result
        long[] rshape = result.shape();
        float[][] rMatrix = new float[(int) rshape[0]][(int) rshape[1]];
        result.copyTo(rMatrix);
        // Argmax over the first row: keep the index of the largest output
        Intermediate value = new Intermediate(0, rMatrix[0][0]);
        for (int i = 1; i < rshape[1]; i++){
            if (rMatrix[0][i] > value.getValue()) {
                value.setIndex(i);
                value.setValue(rMatrix[0][i]);
            }
        }
        return (double) value.getIndex();
    }
}
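Each model class implements score, which runs inference and can be overridden to suit your needs. After the connect, the function that processes the connected streams works as follows:

1. It calls the current model's score method on each incoming record.
2. The return value is the prediction.
3. It then updates the serving statistics.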
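The data-side step can be sketched in plain Java: score one record with the current model, then update the serving statistics. The ServingStats fields (call count, total/min/max scoring time) are a hypothetical choice, since the article does not list which statistics are kept. A ToDoubleFunction stands in for the model's score method.

```java
import java.util.function.ToDoubleFunction;

/** Sketch: score a record with the current model and update serving statistics. */
public class ScoreWithStatsSketch {

    /** Hypothetical serving statistics: call count plus total/min/max scoring time in ns. */
    public static final class ServingStats {
        public long invocations;
        public long totalNanos;
        public long minNanos = Long.MAX_VALUE;
        public long maxNanos = Long.MIN_VALUE;

        public void update(long elapsedNanos) {
            invocations++;
            totalNanos += elapsedNanos;
            minNanos = Math.min(minNanos, elapsedNanos);
            maxNanos = Math.max(maxNanos, elapsedNanos);
        }

        public double averageNanos() {
            return invocations == 0 ? 0.0 : (double) totalNanos / invocations;
        }
    }

    /** Score one record and record how long the call took. */
    public static <R> double scoreAndRecord(ToDoubleFunction<R> model, R record, ServingStats stats) {
        long start = System.nanoTime();
        double result = model.applyAsDouble(record);
        stats.update(System.nanoTime() - start);
        return result;
    }

    public static void main(String[] args) {
        ServingStats stats = new ServingStats();
        // A stand-in "model": the sum of the features
        ToDoubleFunction<double[]> model = features -> {
            double sum = 0;
            for (double f : features) {
                sum += f;
            }
            return sum;
        };
        System.out.println(scoreAndRecord(model, new double[]{1.0, 2.0, 3.0}, stats)); // 6.0
        System.out.println(scoreAndRecord(model, new double[]{0.5}, stats));           // 0.5
        System.out.println(stats.invocations);                                          // 2
    }
}
```

In a Flink job the statistics object would itself live in operator state next to the models, so that it survives restarts along with them.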
PMML models need one extra step: the PMML document must be parsed (unmarshalled) into an object model that Java code can evaluate.
/**
 * Creates a new PMML model.
 *
 * @param input binary representation of PMML model.
 */
public PMMLModel(byte[] input) throws Throwable{
    // Save bytes
    bytes = input;
    // Unmarshal PMML
    pmml = PMMLUtil.unmarshal(new ByteArrayInputStream(input));
    // Optimize model
    synchronized (this) {
        for (Visitor optimizer : optimizers) {
            try {
                optimizer.applyTo(pmml);
            } catch (Throwable t){
                // An optimizer failure is ignored; the unoptimized PMML is used as-is
            }
        }
    }
    // ... remainder of the constructor not shown in this excerpt
}
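The Flink model server currently implements only these two model types and does not appear to be under further development. Depending on your needs, you can build on the existing code to support other formats, such as PyTorch models. Because each key keeps its own model in state, this design can in principle serve thousands of models side by side.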
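Adding a new model type amounts to two steps: write a factory whose create method builds the new model, then register it under a new type number. A minimal sketch follows. The PYTORCH constant, its number, and the stand-in model are all hypothetical. Loading a real PyTorch model in Java would need a separate runtime binding, which is not shown.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/** Sketch of registering a new model type next to TensorFlow and PMML (hypothetical names). */
public class NewModelTypeSketch {

    /** Cut-down versions of the Model and ModelFactory roles described in the article. */
    public interface Model {
        Object score(Object input);
    }

    public interface ModelFactory {
        Optional<Model> create(byte[] modelData);
    }

    public static final int PYTORCH = 100; // hypothetical type number for the new model type

    private static final Map<Integer, ModelFactory> FACTORIES = new HashMap<>();

    /** Register a factory for a model type. */
    public static void register(int type, ModelFactory factory) {
        FACTORIES.put(type, factory);
    }

    /** Look up the factory for a type and build a model; empty if either step fails. */
    public static Optional<Model> createModel(int type, byte[] modelData) {
        return Optional.ofNullable(FACTORIES.get(type)).flatMap(f -> f.create(modelData));
    }

    public static void main(String[] args) {
        // Stand-in factory: a real one would load modelData into a PyTorch runtime
        register(PYTORCH, modelData -> Optional.<Model>of(input -> ((double[]) input).length));
        System.out.println(createModel(PYTORCH, new byte[0])
                .map(m -> m.score(new double[]{1.0, 2.0, 3.0}))
                .orElse("no model"));                                 // 3
        System.out.println(createModel(42, new byte[0]).isPresent()); // false
    }
}
```

The rest of the pipeline stays unchanged: the model records only need to carry the new type number, and the connected-stream logic calls score the same way for every type.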
The above is only a rough outline of the overall processing flow. Source code:
https://github.com/apache/flink/tree/e9d32fd8eff6a45e38e1eee079a49589a21add15
This article was reprinted from IT技术小输出.