Scala for的妙用

今天重溫scala的時候想起前同事教我的一招

當程式特定過程會依賴上一段過程的結果的時候可以用for傳遞

但前提是每個過程的output相同

這方法我第一次看到的時候蠻驚豔的,因為這種寫法才真正發揮scala的妙用

也是這種寫法讓我體會到functional programming的思維跟object-oriented programming的差異

格式大概如下

for( o1 <- function1;
     o2 <- function2(o1);
     o3 <- function3(o2);
     ....
 ) yield oN

最後輸出結果oN是經歷前方種種過程累積出來的精華

記錄一下今天重構時的結果

需求:我想將presto裏面catalog -> schema -> table所有的table列舉出來

最一開使就是暴力法分三段寫

val conn = DriverManager.getConnection(url, "test", null)

//handle catalogs
var catalogs:List[String] =  = List[String]()
var stat = conn.createStatement()
var rs = stat.executeQuery("SHOW catalogs")
while(rs.next()){
  catalogs = catalogs  :+ rs.getString(1)
}
rs.close();stat.close()

//handle schemas
var schemas:List[String] =  = List[String]()
stat = conn.createStatement()
for(c <- catalogs){
  var rs = stat.executeQuery(s"SHOW schemas FROM $c")
  while(rs.next()){
    schemas = schemas  :+ s"$c." + rs.getString(1)
  }
}
rs.close();stat.close()

//handle tables
var tables:List[String] =  = List[String]()
stat = conn.createStatement()
for(s <- schemas){
  var rs = stat.executeQuery(s"SHOW tables FROM $s")
  while(rs.next()){
    tables = tables  :+ s"$s." + rs.getString(1)
  }
}
rs.close();stat.close()

println(tables)

就這樣三個過程我最終得到了${catalog}.${schema}.${table}的列表輸出

後來進行重構的過程中,發現catalog,schema,table的三個過程最後輸出的都是List[String]

而且過程中是順序依賴,這時候就想到可以for將傳遞過程簡化

首先我先將query過程變成一個函式

 def query(conn:Connection, sql:String, prefix:String):List[String] = {
    var results:List[String] = List[String]()
    val stat = conn.createStatement()
    val rs = stat.executeQuery(sql)
    while(rs.next()){
      results = results  :+ (if(prefix.length > 0) s"$prefix." else "") + rs.getString(1)
    }
    rs.close();stat.close()
    results
  }

之後我的程式就能簡化成

  val tables = for (c <- query(conn, "Show catalogs", "");
                    s <- query(conn, s"Show schemas from $c", c);
                    t <- query(conn, s"Show tables from $s", s)
  ) yield t

用for..yield的特性得到輸出的最後結果

這是一個比較單純的範例,若用在實務上甚至可以跟future等物件做結合去實作一些比較複雜的演算法

譬如說想實作一個machine learning的過程,對一組資料進行預測

for( f <- extract(dataset);   //抽取資料集的feature
     t <- train(f, model);    //進行training
     p <- predict(t)) yield p //回傳預測結果
  

*上面演算法是憑印象寫的,不一定正確

這樣就完成了一個簡易演算法框架,只要稍微加工一下就能對每個演算法步驟進行模組化替換

comments powered by Disqus